Commit 6ee61a4c authored by Joel Martin's avatar Joel Martin

Add daemonization support to wsproxy.*.

Refactor how settings are passed around.
parent b2fd1bc3
...@@ -159,4 +159,7 @@ if __name__ == '__main__': ...@@ -159,4 +159,7 @@ if __name__ == '__main__':
for i in range(0, 100000): for i in range(0, 100000):
rand_array.append(random.randint(0, 9)) rand_array.append(random.randint(0, 9))
start_server(listen_port, test_handler) settings['listen_port'] = listen_port
settings['daemon'] = False
settings['handler'] = test_handler
start_server()
...@@ -81,4 +81,7 @@ if __name__ == '__main__': ...@@ -81,4 +81,7 @@ if __name__ == '__main__':
print "Usage: <listen_port>" print "Usage: <listen_port>"
sys.exit(1) sys.exit(1)
start_server(listen_port, responder) settings['listen_port'] = listen_port
settings['daemon'] = False
settings['handler'] = responder
start_server()
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h> #include <netdb.h>
#include <signal.h> // daemonizing
#include <fcntl.h> // daemonizing
#include <openssl/err.h> #include <openssl/err.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <resolv.h> /* base64 encode/decode */ #include <resolv.h> /* base64 encode/decode */
...@@ -37,6 +39,7 @@ const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\ ...@@ -37,6 +39,7 @@ const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\
int ssl_initialized = 0; int ssl_initialized = 0;
char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp; char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
unsigned int bufsize, dbufsize; unsigned int bufsize, dbufsize;
settings_t settings;
client_settings_t client_settings; client_settings_t client_settings;
void traffic(char * token) { void traffic(char * token) {
...@@ -269,7 +272,7 @@ int decode(char *src, size_t srclength, u_char *target, size_t targsize) { ...@@ -269,7 +272,7 @@ int decode(char *src, size_t srclength, u_char *target, size_t targsize) {
return retlen; return retlen;
} }
ws_ctx_t *do_handshake(int sock, int ssl_only) { ws_ctx_t *do_handshake(int sock) {
char handshake[4096], response[4096]; char handshake[4096], response[4096];
char *scheme, *line, *path, *host, *origin; char *scheme, *line, *path, *host, *origin;
char *args_start, *args_end, *arg_idx; char *args_start, *args_end, *arg_idx;
...@@ -281,6 +284,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { ...@@ -281,6 +284,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
client_settings.do_seq_num = 0; client_settings.do_seq_num = 0;
client_settings.seq_num = 0; client_settings.seq_num = 0;
// Peek, but don't read the data
len = recv(sock, handshake, 1024, MSG_PEEK); len = recv(sock, handshake, 1024, MSG_PEEK);
handshake[len] = 0; handshake[len] = 0;
if (bcmp(handshake, "<policy-file-request/>", 22) == 0) { if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
...@@ -292,11 +296,11 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { ...@@ -292,11 +296,11 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
return NULL; return NULL;
} else if (bcmp(handshake, "\x16", 1) == 0) { } else if (bcmp(handshake, "\x16", 1) == 0) {
// SSL // SSL
ws_ctx = ws_socket_ssl(sock, "self.pem"); ws_ctx = ws_socket_ssl(sock, settings.cert);
if (! ws_ctx) { return NULL; } if (! ws_ctx) { return NULL; }
scheme = "wss"; scheme = "wss";
printf("Using SSL socket\n"); printf(" using SSL socket\n");
} else if (ssl_only) { } else if (settings.ssl_only) {
printf("Non-SSL connection disallowed"); printf("Non-SSL connection disallowed");
close(sock); close(sock);
return NULL; return NULL;
...@@ -304,7 +308,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { ...@@ -304,7 +308,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
ws_ctx = ws_socket(sock); ws_ctx = ws_socket(sock);
if (! ws_ctx) { return NULL; } if (! ws_ctx) { return NULL; }
scheme = "ws"; scheme = "ws";
printf("Using plain (not SSL) socket\n"); printf(" using plain (not SSL) socket\n");
} }
len = ws_recv(ws_ctx, handshake, 4096); len = ws_recv(ws_ctx, handshake, 4096);
handshake[len] = 0; handshake[len] = 0;
...@@ -327,7 +331,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { ...@@ -327,7 +331,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
//printf("host: %s\n", host); //printf("host: %s\n", host);
//printf("origin: %s\n", origin); //printf("origin: %s\n", origin);
// TODO: parse out client settings // Parse client settings from the GET path
args_start = strstr(path, "?"); args_start = strstr(path, "?");
if (args_start) { if (args_start) {
if (strstr(args_start, "#")) { if (strstr(args_start, "#")) {
...@@ -337,31 +341,70 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { ...@@ -337,31 +341,70 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
} }
arg_idx = strstr(args_start, "b64encode"); arg_idx = strstr(args_start, "b64encode");
if (arg_idx && arg_idx < args_end) { if (arg_idx && arg_idx < args_end) {
//printf("setting b64encode\n"); printf(" b64encode=1\n");
client_settings.do_b64encode = 1; client_settings.do_b64encode = 1;
} }
arg_idx = strstr(args_start, "seq_num"); arg_idx = strstr(args_start, "seq_num");
if (arg_idx && arg_idx < args_end) { if (arg_idx && arg_idx < args_end) {
//printf("setting seq_num\n"); printf(" seq_num=1\n");
client_settings.do_seq_num = 1; client_settings.do_seq_num = 1;
} }
} }
sprintf(response, server_handshake, origin, scheme, host, path); sprintf(response, server_handshake, origin, scheme, host, path);
printf("response: %s\n", response); //printf("response: %s\n", response);
ws_send(ws_ctx, response, strlen(response)); ws_send(ws_ctx, response, strlen(response));
return ws_ctx; return ws_ctx;
} }
void start_server(int listen_port, void signal_handler(sig) {
void (*handler)(ws_ctx_t*), switch (sig) {
char *listen_host, case SIGHUP: break; // ignore
int ssl_only) { case SIGTERM: exit(0); break;
}
}
void daemonize() {
int pid, i;
umask(0);
chdir('/');
setgid(getgid());
setuid(getuid());
/* Double fork to daemonize */
pid = fork();
if (pid<0) { fatal("fork error"); }
if (pid>0) { exit(0); } // parent exits
setsid(); // Obtain new process group
pid = fork();
if (pid<0) { fatal("fork error"); }
if (pid>0) { exit(0); } // parent exits
/* Signal handling */
signal(SIGHUP, signal_handler); // catch HUP
signal(SIGTERM, signal_handler); // catch kill
/* Close open files */
for (i=getdtablesize(); i>=0; --i) {
close(i);
}
i=open("/dev/null", O_RDWR); // Redirect stdin
dup(i); // Redirect stdout
dup(i); // Redirect stderr
}
void start_server() {
int lsock, csock, clilen, sopt = 1, i; int lsock, csock, clilen, sopt = 1, i;
struct sockaddr_in serv_addr, cli_addr; struct sockaddr_in serv_addr, cli_addr;
ws_ctx_t *ws_ctx; ws_ctx_t *ws_ctx;
if (settings.daemon) {
daemonize();
}
/* Initialize buffers */ /* Initialize buffers */
bufsize = 65536; bufsize = 65536;
if (! (tbuf = malloc(bufsize)) ) if (! (tbuf = malloc(bufsize)) )
...@@ -377,15 +420,15 @@ void start_server(int listen_port, ...@@ -377,15 +420,15 @@ void start_server(int listen_port,
if (lsock < 0) { error("ERROR creating listener socket"); } if (lsock < 0) { error("ERROR creating listener socket"); }
bzero((char *) &serv_addr, sizeof(serv_addr)); bzero((char *) &serv_addr, sizeof(serv_addr));
serv_addr.sin_family = AF_INET; serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(listen_port); serv_addr.sin_port = htons(settings.listen_port);
/* Resolve listen address */ /* Resolve listen address */
if ((listen_host == NULL) || (listen_host[0] == '\0')) { if (settings.listen_host && (settings.listen_host[0] != '\0')) {
serv_addr.sin_addr.s_addr = INADDR_ANY; if (resolve_host(&serv_addr.sin_addr, settings.listen_host) < -1) {
} else {
if (resolve_host(&serv_addr.sin_addr, listen_host) < -1) {
fatal("Could not resolve listen address"); fatal("Could not resolve listen address");
} }
} else {
serv_addr.sin_addr.s_addr = INADDR_ANY;
} }
setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt)); setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
...@@ -396,10 +439,12 @@ void start_server(int listen_port, ...@@ -396,10 +439,12 @@ void start_server(int listen_port,
while (1) { while (1) {
clilen = sizeof(cli_addr); clilen = sizeof(cli_addr);
if (listen_host) { if (settings.listen_host && settings.listen_host[0] != '\0') {
printf("waiting for connection on %s:%d\n", listen_host, listen_port); printf("waiting for connection on %s:%d\n",
settings.listen_host, settings.listen_port);
} else { } else {
printf("waiting for connection on port %d\n", listen_port); printf("waiting for connection on port %d\n",
settings.listen_port);
} }
csock = accept(lsock, csock = accept(lsock,
(struct sockaddr *) &cli_addr, (struct sockaddr *) &cli_addr,
...@@ -409,7 +454,7 @@ void start_server(int listen_port, ...@@ -409,7 +454,7 @@ void start_server(int listen_port,
continue; continue;
} }
printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr)); printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
ws_ctx = do_handshake(csock, ssl_only); ws_ctx = do_handshake(csock);
if (ws_ctx == NULL) { if (ws_ctx == NULL) {
close(csock); close(csock);
continue; continue;
...@@ -425,7 +470,7 @@ void start_server(int listen_port, ...@@ -425,7 +470,7 @@ void start_server(int listen_port,
dbufsize = (bufsize/2) - 20; dbufsize = (bufsize/2) - 20;
} }
handler(ws_ctx); settings.handler(ws_ctx);
close(csock); close(csock);
} }
......
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <unistd.h>
typedef struct { typedef struct {
int sockfd; int sockfd;
...@@ -6,6 +7,16 @@ typedef struct { ...@@ -6,6 +7,16 @@ typedef struct {
SSL *ssl; SSL *ssl;
} ws_ctx_t; } ws_ctx_t;
typedef struct {
char listen_host[256];
int listen_port;
void (*handler)(ws_ctx_t*);
int ssl_only;
int daemon;
char record[1024];
char cert[1024];
} settings_t;
typedef struct { typedef struct {
int do_b64encode; int do_b64encode;
int do_seq_num; int do_seq_num;
......
...@@ -10,9 +10,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ...@@ -10,9 +10,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import sys, socket, ssl, traceback import sys, socket, ssl, traceback
import os, resource, errno, signal # daemonizing
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
client_settings = {} settings = {
'listen_host' : '',
'listen_port' : None,
'handler' : None,
'cert' : None,
'ssl_only' : False,
'daemon' : True,
'record' : None, }
client_settings = {
'b64encode' : False,
'seq_num' : False, }
send_seq = 0 send_seq = 0
server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
...@@ -33,35 +45,39 @@ def traffic(token="."): ...@@ -33,35 +45,39 @@ def traffic(token="."):
def decode(buf): def decode(buf):
""" Parse out WebSocket packets. """ """ Parse out WebSocket packets. """
if buf.count('\xff') > 1: if buf.count('\xff') > 1:
if client_settings["b64encode"]: if client_settings['b64encode']:
return [b64decode(d[1:]) for d in buf.split('\xff')] return [b64decode(d[1:]) for d in buf.split('\xff')]
else: else:
# Modified UTF-8 decode # Modified UTF-8 decode
return [d[1:].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1') for d in buf.split('\xff')] return [d[1:].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1') for d in buf.split('\xff')]
else: else:
if client_settings["b64encode"]: if client_settings['b64encode']:
return [b64decode(buf[1:-1])] return [b64decode(buf[1:-1])]
else: else:
return [buf[1:-1].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1')] return [buf[1:-1].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1')]
def encode(buf): def encode(buf):
global send_seq global send_seq
if client_settings["b64encode"]: if client_settings['b64encode']:
buf = b64encode(buf) buf = b64encode(buf)
else: else:
# Modified UTF-8 encode # Modified UTF-8 encode
buf = buf.decode('latin-1').encode('utf-8').replace("\x00", "\xc4\x80") buf = buf.decode('latin-1').encode('utf-8').replace("\x00", "\xc4\x80")
if client_settings["seq_num"]: if client_settings['seq_num']:
send_seq += 1 send_seq += 1
return "\x00%d:%s\xff" % (send_seq-1, buf) return "\x00%d:%s\xff" % (send_seq-1, buf)
else: else:
return "\x00%s\xff" % buf return "\x00%s\xff" % buf
def do_handshake(sock, ssl_only=False): def do_handshake(sock):
global client_settings, send_seq global client_settings, send_seq
client_settings['b64encode'] = False
client_settings['seq_num'] = False
send_seq = 0 send_seq = 0
# Peek, but don't read the data # Peek, but don't read the data
handshake = sock.recv(1024, socket.MSG_PEEK) handshake = sock.recv(1024, socket.MSG_PEEK)
#print "Handshake [%s]" % repr(handshake) #print "Handshake [%s]" % repr(handshake)
...@@ -75,54 +91,88 @@ def do_handshake(sock, ssl_only=False): ...@@ -75,54 +91,88 @@ def do_handshake(sock, ssl_only=False):
retsock = ssl.wrap_socket( retsock = ssl.wrap_socket(
sock, sock,
server_side=True, server_side=True,
certfile='self.pem', certfile=settings['cert'],
ssl_version=ssl.PROTOCOL_TLSv1) ssl_version=ssl.PROTOCOL_TLSv1)
scheme = "wss" scheme = "wss"
print "Using SSL/TLS" print " using SSL/TLS"
elif ssl_only: elif settings['ssl_only']:
print "Non-SSL connection disallowed" print "Non-SSL connection disallowed"
sock.close() sock.close()
return False return False
else: else:
retsock = sock retsock = sock
scheme = "ws" scheme = "ws"
print "Using plain (not SSL) socket" print " using plain (not SSL) socket"
handshake = retsock.recv(4096) handshake = retsock.recv(4096)
req_lines = handshake.split("\r\n") req_lines = handshake.split("\r\n")
_, path, _ = req_lines[0].split(" ") _, path, _ = req_lines[0].split(" ")
_, origin = req_lines[4].split(" ") _, origin = req_lines[4].split(" ")
_, host = req_lines[3].split(" ") _, host = req_lines[3].split(" ")
# Parse settings from the path # Parse client settings from the GET path
cvars = path.partition('?')[2].partition('#')[0].split('&') cvars = path.partition('?')[2].partition('#')[0].split('&')
client_settings = {'b64encode': None, 'seq_num': None}
for cvar in [c for c in cvars if c]: for cvar in [c for c in cvars if c]:
name, _, value = cvar.partition('=') name, _, val = cvar.partition('=')
client_settings[name] = value and value or True if name not in ['b64encode', 'seq_num']: continue
value = val and val or True
print "client_settings:", client_settings client_settings[name] = value
print " %s=%s" % (name, value)
retsock.send(server_handshake % (origin, scheme, host, path)) retsock.send(server_handshake % (origin, scheme, host, path))
return retsock return retsock
def start_server(listen_port, handler, listen_host='', ssl_only=False): def daemonize():
os.umask(0)
os.chdir('/')
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
# Signal handling
def terminate(a,b): os._exit(0)
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256
for fd in reversed(range(maxfd)):
try:
os.close(fd)
except OSError, exc:
if exc.errno != errno.EBADF: raise
# Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
def start_server():
if settings['daemon']: daemonize()
lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
lsock.bind((listen_host, listen_port)) lsock.bind((settings['listen_host'], settings['listen_port']))
lsock.listen(100) lsock.listen(100)
while True: while True:
try: try:
csock = None csock = startsock = None
print 'waiting for connection on port %s' % listen_port print 'waiting for connection on port %s' % settings['listen_port']
startsock, address = lsock.accept() startsock, address = lsock.accept()
print 'Got client connection from %s' % address[0] print 'Got client connection from %s' % address[0]
csock = do_handshake(startsock, ssl_only=ssl_only) csock = do_handshake(startsock)
if not csock: continue if not csock: continue
handler(csock) settings['handler'](csock)
except Exception: except Exception:
print "Ignoring exception:" print "Ignoring exception:"
print traceback.format_exc() print traceback.format_exc()
if csock: csock.close() if csock: csock.close()
if startsock and startsock != csock: startsock.close()
...@@ -36,11 +36,11 @@ void usage() { ...@@ -36,11 +36,11 @@ void usage() {
exit(1); exit(1);
} }
char *target_host; char target_host[256];
int target_port; int target_port;
char *record_filename = NULL;
int recordfd = 0; int recordfd = 0;
extern settings_t settings;
extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp; extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
extern unsigned int bufsize, dbufsize; extern unsigned int bufsize, dbufsize;
...@@ -198,6 +198,11 @@ void proxy_handler(ws_ctx_t *ws_ctx) { ...@@ -198,6 +198,11 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
int tsock = 0; int tsock = 0;
struct sockaddr_in taddr; struct sockaddr_in taddr;
if (settings.record) {
recordfd = open(settings.record, O_WRONLY | O_CREAT | O_TRUNC,
S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
}
printf("Connecting to: %s:%d\n", target_host, target_port); printf("Connecting to: %s:%d\n", target_host, target_port);
tsock = socket(AF_INET, SOCK_STREAM, 0); tsock = socket(AF_INET, SOCK_STREAM, 0);
...@@ -220,11 +225,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) { ...@@ -220,11 +225,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
return; return;
} }
if (record_filename) {
recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC,
S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
}
printf("%s", traffic_legend); printf("%s", traffic_legend);
do_proxy(ws_ctx, tsock); do_proxy(ws_ctx, tsock);
...@@ -239,52 +239,74 @@ void proxy_handler(ws_ctx_t *ws_ctx) { ...@@ -239,52 +239,74 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
int listen_port, c, option_index = 0; int listen_port, c, option_index = 0;
static int ssl_only = 0; static int ssl_only = 0, foreground = 0;
char *listen_host; char *found;
static struct option long_options[] = { static struct option long_options[] = {
{"ssl-only", no_argument, &ssl_only, 1}, {"ssl-only", no_argument, &ssl_only, 1 },
{"foreground", no_argument, &foreground, 'f'},
/* ---- */ /* ---- */
{"record", required_argument, 0, 'r'}, {"record", required_argument, 0, 'r'},
{"cert", required_argument, 0, 'c'},
{0, 0, 0, 0} {0, 0, 0, 0}
}; };
settings.record[0] = '\0';
strcpy(settings.cert, "self.pem");
while (1) { while (1) {
c = getopt_long (argc, argv, "r:", c = getopt_long (argc, argv, "fr:c:",
long_options, &option_index); long_options, &option_index);
/* Detect the end */ /* Detect the end */
if (c == -1) { break; } if (c == -1) { break; }
switch (c) { switch (c) {
case 0: break; // ignore case 0:
case 1: break; // ignore break; // ignore
case 'r': record_filename = optarg; break; case 1:
default: usage(); break; // ignore
case 'f':
foreground = 1;
break;
case 'r':
memcpy(settings.record, optarg, sizeof(settings.record));
break;
case 'c':
memcpy(settings.cert, optarg, sizeof(settings.cert));
break;
default:
usage();
} }
} }
settings.ssl_only = ssl_only;
settings.daemon = foreground ? 0: 1;
printf("ssl_only: %d\n", ssl_only); printf(" ssl_only: %d\n", settings.ssl_only);
printf("record_filename: %s\n", record_filename); printf(" daemon: %d\n", settings.daemon);
printf(" record: %s\n", settings.record);
printf(" cert: %s\n", settings.cert);
if ((argc-optind) != 2) { if ((argc-optind) != 2) {
usage(); usage();
} }
if (strstr(argv[optind], ":")) { found = strstr(argv[optind], ":");
listen_host = strtok(argv[optind], ":"); if (found) {
listen_port = strtol(strtok(NULL, ":"), NULL, 10); memcpy(settings.listen_host, argv[optind], found-argv[optind]);
settings.listen_port = strtol(found+1, NULL, 10);
} else { } else {
listen_host = NULL; settings.listen_host[0] = '\0';
listen_port = strtol(argv[optind], NULL, 10); settings.listen_port = strtol(argv[optind], NULL, 10);
} }
optind++; optind++;
if ((errno != 0) || (listen_port == 0)) { if ((errno != 0) || (listen_port == 0)) {
usage(); usage();
} }
if (strstr(argv[optind], ":")) { found = strstr(argv[optind], ":");
target_host = strtok(argv[optind], ":"); if (found) {
target_port = strtol(strtok(NULL, ":"), NULL, 10); memcpy(target_host, argv[optind], found-argv[optind]);
target_port = strtol(found+1, NULL, 10);
} else { } else {
usage(); usage();
} }
...@@ -303,7 +325,8 @@ int main(int argc, char *argv[]) ...@@ -303,7 +325,8 @@ int main(int argc, char *argv[])
if (! (cbuf_tmp = malloc(bufsize)) ) if (! (cbuf_tmp = malloc(bufsize)) )
{ fatal("malloc()"); } { fatal("malloc()"); }
start_server(listen_port, &proxy_handler, listen_host, ssl_only); settings.handler = proxy_handler;
start_server();
free(tbuf); free(tbuf);
free(cbuf); free(cbuf);
......
...@@ -99,14 +99,14 @@ def do_proxy(client, target): ...@@ -99,14 +99,14 @@ def do_proxy(client, target):
def proxy_handler(client): def proxy_handler(client):
global target_host, target_port, options, rec global target_host, target_port, options, rec
if settings['record']:
print "Opening record file: %s" % settings['record']
rec = open(settings['record'], 'w')
print "Connecting to: %s:%s" % (target_host, target_port) print "Connecting to: %s:%s" % (target_host, target_port)
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tsock.connect((target_host, target_port)) tsock.connect((target_host, target_port))
if options.record:
print "Opening record file: %s" % options.record
rec = open(options.record, 'w')
print traffic_legend print traffic_legend
try: try:
...@@ -122,25 +122,35 @@ if __name__ == '__main__': ...@@ -122,25 +122,35 @@ if __name__ == '__main__':
parser = optparse.OptionParser(usage=usage) parser = optparse.OptionParser(usage=usage)
parser.add_option("--record", parser.add_option("--record",
help="record session to a file", metavar="FILE") help="record session to a file", metavar="FILE")
parser.add_option("--foreground", "-f",
dest="daemon", default=True, action="store_false",
help="stay in foreground, do not daemonize")
parser.add_option("--ssl-only", action="store_true", parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections") help="disallow non-encrypted connections")
parser.add_option("--cert", default="self.pem",
help="SSL certificate")
(options, args) = parser.parse_args() (options, args) = parser.parse_args()
if len(args) > 2: parser.error("Too many arguments") if len(args) > 2: parser.error("Too many arguments")
if len(args) < 2: parser.error("Too few arguments") if len(args) < 2: parser.error("Too few arguments")
if args[0].count(':') > 0: if args[0].count(':') > 0:
listen_host,listen_port = args[0].split(':') host,port = args[0].split(':')
else: else:
listen_host = '' host,port = '',args[0]
listen_port = args[0]
if args[1].count(':') > 0: if args[1].count(':') > 0:
target_host,target_port = args[1].split(':') target_host,target_port = args[1].split(':')
else: else:
parser.error("Error parsing target") parser.error("Error parsing target")
try: listen_port = int(listen_port) try: port = int(port)
except: parser.error("Error parsing listen port") except: parser.error("Error parsing listen port")
try: target_port = int(target_port) try: target_port = int(target_port)
except: parser.error("Error parsing target port") except: parser.error("Error parsing target port")
start_server(listen_port, proxy_handler, listen_host=listen_host, settings['listen_host'] = host
ssl_only=options.ssl_only) settings['listen_port'] = port
settings['handler'] = proxy_handler
settings['cert'] = os.path.abspath(options.cert)
settings['ssl_only'] = options.ssl_only
settings['daemon'] = options.daemon
settings['record'] = os.path.abspath(options.record)
start_server()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment