Commit 9e61a9c6 authored by Joel Martin's avatar Joel Martin

C wsproxy: seq numbers and decode multiple frames.

parent 8e1aa95b
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <openssl/err.h> #include <openssl/err.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <resolv.h> /* base64 encode/decode */
#include "websocket.h" #include "websocket.h"
const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\ const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
...@@ -27,6 +28,16 @@ WebSocket-Protocol: sample\r\n\ ...@@ -27,6 +28,16 @@ WebSocket-Protocol: sample\r\n\
const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n"; const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n";
/*
* Global state
*
* Warning: not thread safe
*/
int ssl_initialized = 0;
char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
unsigned int bufsize, dbufsize;
client_settings_t client_settings;
void traffic(char * token) { void traffic(char * token) {
fprintf(stdout, "%s", token); fprintf(stdout, "%s", token);
fflush(stdout); fflush(stdout);
...@@ -47,9 +58,6 @@ void fatal(char *msg) ...@@ -47,9 +58,6 @@ void fatal(char *msg)
* SSL Wrapper Code * SSL Wrapper Code
*/ */
/* Warning: not thread safe */
int ssl_initialized = 0;
ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) { ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) {
if (ctx->ssl) { if (ctx->ssl) {
//printf("SSL recv\n"); //printf("SSL recv\n");
...@@ -147,7 +155,56 @@ int ws_socket_free(ws_ctx_t *ctx) { ...@@ -147,7 +155,56 @@ int ws_socket_free(ws_ctx_t *ctx) {
/* ------------------------------------------------------- */ /* ------------------------------------------------------- */
ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) { int encode(u_char const *src, size_t srclength, char *target, size_t targsize) {
int sz = 0, len = 0;
target[sz++] = '\x00';
if (client_settings.do_seq_num) {
sz += sprintf(target+sz, "%d:", client_settings.seq_num);
client_settings.seq_num++;
}
if (client_settings.do_b64encode) {
len = __b64_ntop(src, srclength, target+sz, targsize-sz);
} else {
fatal("UTF-8 not yet implemented");
}
if (len < 0) {
return len;
}
sz += len;
target[sz++] = '\xff';
return sz;
}
int decode(char *src, size_t srclength, u_char *target, size_t targsize) {
char *start, *end;
int len, retlen = 0;
if ((src[0] != '\x00') || (src[srclength-1] != '\xff')) {
fprintf(stderr, "WebSocket framing error\n");
return -1;
}
start = src+1; // Skip '\x00' start
do {
/* We may have more than one frame */
end = strchr(start, '\xff');
if (end < (src+srclength-1)) {
printf("More than one frame to decode\n");
}
*end = '\x00';
if (client_settings.do_b64encode) {
len = __b64_pton(start, target+retlen, targsize-retlen);
} else {
fatal("UTF-8 not yet implemented");
}
if (len < 0) {
return len;
}
retlen += len;
start = end + 2; // Skip '\xff' end and '\x00' start
} while (end < (src+srclength-1));
return retlen;
}
ws_ctx_t *do_handshake(int sock) {
char handshake[4096], response[4096]; char handshake[4096], response[4096];
char *scheme, *line, *path, *host, *origin; char *scheme, *line, *path, *host, *origin;
char *args_start, *args_end, *arg_idx; char *args_start, *args_end, *arg_idx;
...@@ -155,8 +212,9 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) { ...@@ -155,8 +212,9 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
ws_ctx_t * ws_ctx; ws_ctx_t * ws_ctx;
// Reset settings // Reset settings
client_settings->b64encode = 0; client_settings.do_b64encode = 0;
client_settings->seq_num = 0; client_settings.do_seq_num = 0;
client_settings.seq_num = 0;
len = recv(sock, handshake, 1024, MSG_PEEK); len = recv(sock, handshake, 1024, MSG_PEEK);
handshake[len] = 0; handshake[len] = 0;
...@@ -211,12 +269,12 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) { ...@@ -211,12 +269,12 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
arg_idx = strstr(args_start, "b64encode"); arg_idx = strstr(args_start, "b64encode");
if (arg_idx && arg_idx < args_end) { if (arg_idx && arg_idx < args_end) {
//printf("setting b64encode\n"); //printf("setting b64encode\n");
client_settings->b64encode = 1; client_settings.do_b64encode = 1;
} }
arg_idx = strstr(args_start, "seq_num"); arg_idx = strstr(args_start, "seq_num");
if (arg_idx && arg_idx < args_end) { if (arg_idx && arg_idx < args_end) {
//printf("setting seq_num\n"); //printf("setting seq_num\n");
client_settings->seq_num = 1; client_settings.do_seq_num = 1;
} }
} }
...@@ -228,12 +286,22 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) { ...@@ -228,12 +286,22 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
} }
void start_server(int listen_port, void start_server(int listen_port,
void (*handler)(ws_ctx_t*), void (*handler)(ws_ctx_t*)) {
client_settings_t *client_settings) {
int lsock, csock, clilen, sopt = 1; int lsock, csock, clilen, sopt = 1;
struct sockaddr_in serv_addr, cli_addr; struct sockaddr_in serv_addr, cli_addr;
ws_ctx_t *ws_ctx; ws_ctx_t *ws_ctx;
/* Initialize buffers */
bufsize = 65536;
if (! (tbuf = malloc(bufsize)) )
{ fatal("malloc()"); }
if (! (cbuf = malloc(bufsize)) )
{ fatal("malloc()"); }
if (! (tbuf_tmp = malloc(bufsize)) )
{ fatal("malloc()"); }
if (! (cbuf_tmp = malloc(bufsize)) )
{ fatal("malloc()"); }
lsock = socket(AF_INET, SOCK_STREAM, 0); lsock = socket(AF_INET, SOCK_STREAM, 0);
if (lsock < 0) { error("ERROR creating listener socket"); } if (lsock < 0) { error("ERROR creating listener socket"); }
bzero((char *) &serv_addr, sizeof(serv_addr)); bzero((char *) &serv_addr, sizeof(serv_addr));
...@@ -256,8 +324,23 @@ void start_server(int listen_port, ...@@ -256,8 +324,23 @@ void start_server(int listen_port,
error("ERROR on accept"); error("ERROR on accept");
} }
printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr)); printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
ws_ctx = do_handshake(csock, client_settings); ws_ctx = do_handshake(csock);
if (ws_ctx == NULL) { continue; } if (ws_ctx == NULL) {
close(csock);
continue;
}
/* Calculate dbufsize based on client_settings */
if (client_settings.do_b64encode) {
/* base64 is 4 bytes for every 3
* 20 for WS '\x00' / '\xff', seq_num and good measure */
dbufsize = (bufsize * 3)/4 - 20;
} else {
fatal("UTF-8 not yet implemented");
/* UTF-8 encoding is up to 2X larger */
dbufsize = (bufsize/2) - 15;
}
handler(ws_ctx); handler(ws_ctx);
close(csock); close(csock);
} }
......
...@@ -7,7 +7,8 @@ typedef struct { ...@@ -7,7 +7,8 @@ typedef struct {
} ws_ctx_t; } ws_ctx_t;
typedef struct { typedef struct {
int b64encode; int do_b64encode;
int do_seq_num;
int seq_num; int seq_num;
} client_settings_t; } client_settings_t;
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <netdb.h> #include <netdb.h>
#include <sys/select.h> #include <sys/select.h>
#include <resolv.h>
#include <fcntl.h> #include <fcntl.h>
#include <sys/stat.h> #include <sys/stat.h>
#include "websocket.h" #include "websocket.h"
...@@ -35,23 +34,21 @@ void usage() { ...@@ -35,23 +34,21 @@ void usage() {
char *target_host; char *target_host;
int target_port; int target_port;
client_settings_t client_settings;
char *record_filename = NULL; char *record_filename = NULL;
int recordfd = 0; int recordfd = 0;
char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
unsigned int bufsize, dbufsize; extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
extern unsigned int bufsize, dbufsize;
void do_proxy(ws_ctx_t *ws_ctx, int target) { void do_proxy(ws_ctx_t *ws_ctx, int target) {
fd_set rlist, wlist, elist; fd_set rlist, wlist, elist;
struct timeval tv; struct timeval tv;
int maxfd, client = ws_ctx->sockfd; int i, maxfd, client = ws_ctx->sockfd;
unsigned int tstart, tend, cstart, cend, ret; unsigned int tstart, tend, cstart, cend, ret;
ssize_t len, bytes; ssize_t len, bytes;
tstart = tend = cstart = cend = 0; tstart = tend = cstart = cend = 0;
maxfd = client > target ? client+1 : target+1; maxfd = client > target ? client+1 : target+1;
// Account for base64 encoding and WebSocket delims:
// 49150 = 65536 * 3/4 + 2 - 1
while (1) { while (1) {
tv.tv_sec = 1; tv.tv_sec = 1;
...@@ -137,18 +134,22 @@ void do_proxy(ws_ctx_t *ws_ctx, int target) { ...@@ -137,18 +134,22 @@ void do_proxy(ws_ctx_t *ws_ctx, int target) {
if (FD_ISSET(target, &rlist)) { if (FD_ISSET(target, &rlist)) {
bytes = recv(target, cbuf_tmp, dbufsize , 0); bytes = recv(target, cbuf_tmp, dbufsize , 0);
if (bytes <= 0) { if (bytes <= 0) {
error("target closed connection"); fprintf(stderr, "target closed connection");
break; break;
} }
cbuf[0] = '\x00';
cstart = 0; cstart = 0;
len = b64_ntop(cbuf_tmp, bytes, cbuf+1, bufsize-1); cend = encode(cbuf_tmp, bytes, cbuf, bufsize);
if (len < 0) { /*
fprintf(stderr, "base64 encoding error\n"); printf("encoded: ");
for (i=0; i< bytes; i++) {
printf("%d,", *(cbuf+i));
}
printf("\n");
*/
if (cend < 0) {
fprintf(stderr, "encoding error\n");
break; break;
} }
cbuf[len+1] = '\xff';
cend = len+1+1;
traffic("{"); traffic("{");
} }
...@@ -158,20 +159,21 @@ void do_proxy(ws_ctx_t *ws_ctx, int target) { ...@@ -158,20 +159,21 @@ void do_proxy(ws_ctx_t *ws_ctx, int target) {
fprintf(stderr, "client closed connection\n"); fprintf(stderr, "client closed connection\n");
break; break;
} }
if (tbuf_tmp[bytes-1] != '\xff') {
//traffic(".}");
fprintf(stderr, "Malformed packet\n");
break;
}
if (recordfd) { if (recordfd) {
write(recordfd, "'", 1); write(recordfd, "'", 1);
write(recordfd, tbuf_tmp + 1, bytes - 2); write(recordfd, tbuf_tmp + 1, bytes - 2);
write(recordfd, "',\n", 3); write(recordfd, "',\n", 3);
} }
tbuf_tmp[bytes-1] = '\0'; len = decode(tbuf_tmp, bytes, tbuf, bufsize-1);
len = b64_pton(tbuf_tmp+1, tbuf, bufsize-1); /*
printf("decoded: ");
for (i=0; i< bytes; i++) {
printf("%d,", *(tbuf+i));
}
printf("\n");
*/
if (len < 0) { if (len < 0) {
fprintf(stderr, "base64 decoding error\n"); fprintf(stderr, "decoding error\n");
break; break;
} }
traffic("}"); traffic("}");
...@@ -188,11 +190,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) { ...@@ -188,11 +190,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
printf("Connecting to: %s:%d\n", target_host, target_port); printf("Connecting to: %s:%d\n", target_host, target_port);
if (client_settings.b64encode) {
dbufsize = (bufsize * 3)/4 + 2 - 10; // padding and for good measure
} else {
}
tsock = socket(AF_INET, SOCK_STREAM, 0); tsock = socket(AF_INET, SOCK_STREAM, 0);
if (tsock < 0) { if (tsock < 0) {
error("Could not create target socket"); error("Could not create target socket");
...@@ -260,7 +257,7 @@ int main(int argc, char *argv[]) ...@@ -260,7 +257,7 @@ int main(int argc, char *argv[])
if (! (cbuf_tmp = malloc(bufsize)) ) if (! (cbuf_tmp = malloc(bufsize)) )
{ fatal("malloc()"); } { fatal("malloc()"); }
start_server(listen_port, &proxy_handler, &client_settings); start_server(listen_port, &proxy_handler);
free(tbuf); free(tbuf);
free(cbuf); free(cbuf);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment