Commit 082027dc authored by samhed's avatar samhed

Sync with websockify

Pull 90b519edf0c1857d
parent 60a415ae
...@@ -262,7 +262,7 @@ function on(evt, handler) { ...@@ -262,7 +262,7 @@ function on(evt, handler) {
eventHandlers[evt] = handler; eventHandlers[evt] = handler;
} }
function init(protocols, ws_schema) { function init(protocols) {
rQ = []; rQ = [];
rQi = 0; rQi = 0;
sQ = []; sQ = [];
...@@ -278,14 +278,11 @@ function init(protocols, ws_schema) { ...@@ -278,14 +278,11 @@ function init(protocols, ws_schema) {
bt = true; bt = true;
} }
// Check for full binary type support in WebSocket // Check for full binary type support in WebSockets
// Inspired by: // TODO: this sucks, the property should exist on the prototype
// https://github.com/Modernizr/Modernizr/issues/370 // but it does not.
// https://github.com/Modernizr/Modernizr/blob/master/feature-detects/websockets/binary.js
try { try {
if (bt && if (bt && ('binaryType' in (new WebSocket("ws://localhost:17523")))) {
('binaryType' in WebSocket.prototype ||
!!(new WebSocket(ws_schema + '://.').binaryType))) {
Util.Info("Detected binaryType support in WebSockets"); Util.Info("Detected binaryType support in WebSockets");
wsbt = true; wsbt = true;
} }
...@@ -328,8 +325,7 @@ function init(protocols, ws_schema) { ...@@ -328,8 +325,7 @@ function init(protocols, ws_schema) {
} }
function open(uri, protocols) { function open(uri, protocols) {
var ws_schema = uri.match(/^([a-z]+):\/\//)[1]; protocols = init(protocols);
protocols = init(protocols, ws_schema);
if (test_mode) { if (test_mode) {
websocket = {}; websocket = {};
......
...@@ -16,7 +16,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ...@@ -16,7 +16,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import os, sys, time, errno, signal, socket, traceback, select import os, sys, time, errno, signal, socket, select, logging
import array, struct import array, struct
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
...@@ -59,177 +59,57 @@ for mod, msg in [('numpy', 'HyBi protocol will be slower'), ...@@ -59,177 +59,57 @@ for mod, msg in [('numpy', 'HyBi protocol will be slower'),
except ImportError: except ImportError:
globals()[mod] = None globals()[mod] = None
print("WARNING: no '%s' module, %s" % (mod, msg)) print("WARNING: no '%s' module, %s" % (mod, msg))
if multiprocessing and sys.platform == 'win32': if multiprocessing and sys.platform == 'win32':
# make sockets pickle-able/inheritable # make sockets pickle-able/inheritable
import multiprocessing.reduction import multiprocessing.reduction
class WebSocketServer(object): # HTTP handler with WebSocket upgrade support
class WebSocketRequestHandler(SimpleHTTPRequestHandler):
""" """
WebSockets server class. WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler.
Must be sub-classed with new_client method definition. Must be sub-classed with new_websocket_client method definition.
The request handler can be configured by setting optional
attributes on the server object:
* only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled,
only websocket is allowed.
* verbose: If true, verbose logging is activated.
* daemon: Running as daemon, do not write to console etc
* record: Record raw frame data as JavaScript array into specified filename
* run_once: Handle a single request
* handler_id: A sequence number for this connection, appended to record filename
""" """
buffer_size = 65536 buffer_size = 65536
server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Accept: %s\r
"""
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" server_version = "WebSockify"
# An exception before the WebSocket connection was established protocol_version = "HTTP/1.1"
class EClose(Exception):
pass
# An exception while the WebSocket client was connected # An exception while the WebSocket client was connected
class CClose(Exception): class CClose(Exception):
pass pass
def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, def __init__(self, req, addr, server):
verbose=False, cert='', key='', ssl_only=None, # Retrieve a few configuration variables from the server
daemon=False, record='', web='', self.only_upgrade = getattr(server, "only_upgrade", False)
run_once=False, timeout=0, idle_timeout=0): self.verbose = getattr(server, "verbose", False)
self.daemon = getattr(server, "daemon", False)
# settings self.record = getattr(server, "record", False)
self.verbose = verbose self.run_once = getattr(server, "run_once", False)
self.listen_host = listen_host self.rec = None
self.listen_port = listen_port self.handler_id = getattr(server, "handler_id", False)
self.prefer_ipv6 = source_is_ipv6 self.file_only = getattr(server, "file_only", False)
self.ssl_only = ssl_only self.traffic = getattr(server, "traffic", False)
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
# Make paths settings absolute
self.cert = os.path.abspath(cert)
self.key = self.web = self.record = ''
if key:
self.key = os.path.abspath(key)
if web:
self.web = os.path.abspath(web)
if record:
self.record = os.path.abspath(record)
if self.web:
os.chdir(self.web)
# Sanity checks
if not ssl and self.ssl_only:
raise Exception("No 'ssl' module and SSL-only specified")
if self.daemon and not resource:
raise Exception("Module 'resource' required to daemonize")
# Show configuration
print("WebSocket server settings:")
print(" - Listen on %s:%s" % (
self.listen_host, self.listen_port))
print(" - Flash security policy server")
if self.web:
print(" - Web server. Web root: %s" % self.web)
if ssl:
if os.path.exists(self.cert):
print(" - SSL/TLS support")
if self.ssl_only:
print(" - Deny non-SSL/TLS connections")
else:
print(" - No SSL/TLS support (no cert file)")
else:
print(" - No SSL/TLS support (no 'ssl' module)")
if self.daemon:
print(" - Backgrounding (daemon)")
if self.record:
print(" - Recording to '%s.*'" % self.record)
#
# WebSocketServer static methods
#
@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False, unix_socket=None, use_ssl=False):
""" Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
"""
flags = 0
if host == '':
host = None
if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port")
if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded.");
if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)")
if not connect:
flags = flags | socket.AI_PASSIVE
if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags)
if not addrs:
raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0])
if prefer_ipv6:
addrs.reverse()
sock = socket.socket(addrs[0][0], addrs[0][1])
if connect:
sock.connect(addrs[0][4])
if use_ssl:
sock = ssl.wrap_socket(sock)
else:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(addrs[0][4])
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)
return sock
@staticmethod
def daemonize(keepfd=None, chdir='/'):
os.umask(0)
if chdir:
os.chdir(chdir)
else:
os.chdir('/')
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
# Signal handling
def terminate(a,b): os._exit(0)
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Close open files self.logger = getattr(server, "logger", None)
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] if self.logger is None:
if maxfd == resource.RLIM_INFINITY: maxfd = 256 self.logger = WebSocketServer.get_logger()
for fd in reversed(range(maxfd)):
try:
if fd != keepfd:
os.close(fd)
except OSError:
_, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise
# Redirect I/O to /dev/null SimpleHTTPRequestHandler.__init__(self, req, addr, server)
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
@staticmethod @staticmethod
def unmask(buf, hlen, plen): def unmask(buf, hlen, plen):
...@@ -246,7 +126,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -246,7 +126,7 @@ Sec-WebSocket-Accept: %s\r
b = numpy.bitwise_xor(data, mask).tostring() b = numpy.bitwise_xor(data, mask).tostring()
if plen % 4: if plen % 4:
#print("Partial unmask") #self.msg("Partial unmask")
mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'), mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
offset=hlen, count=(plen % 4)) offset=hlen, count=(plen % 4))
data = numpy.frombuffer(buf, dtype=numpy.dtype('B'), data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
...@@ -287,12 +167,12 @@ Sec-WebSocket-Accept: %s\r ...@@ -287,12 +167,12 @@ Sec-WebSocket-Accept: %s\r
elif payload_len >= 65536: elif payload_len >= 65536:
header = pack('>BBQ', b1, 127, payload_len) header = pack('>BBQ', b1, 127, payload_len)
#print("Encoded: %s" % repr(header + buf)) #self.msg("Encoded: %s", repr(header + buf))
return header + buf, len(header), 0 return header + buf, len(header), 0
@staticmethod @staticmethod
def decode_hybi(buf, base64=False): def decode_hybi(buf, base64=False, logger=None):
""" Decode HyBi style WebSocket packets. """ Decode HyBi style WebSocket packets.
Returns: Returns:
{'fin' : 0_or_1, {'fin' : 0_or_1,
...@@ -316,6 +196,9 @@ Sec-WebSocket-Accept: %s\r ...@@ -316,6 +196,9 @@ Sec-WebSocket-Accept: %s\r
'close_code' : 1000, 'close_code' : 1000,
'close_reason' : ''} 'close_reason' : ''}
if logger is None:
logger = WebSocketServer.get_logger()
blen = len(buf) blen = len(buf)
f['left'] = blen f['left'] = blen
...@@ -351,18 +234,18 @@ Sec-WebSocket-Accept: %s\r ...@@ -351,18 +234,18 @@ Sec-WebSocket-Accept: %s\r
# Process 1 frame # Process 1 frame
if f['masked']: if f['masked']:
# unmask payload # unmask payload
f['payload'] = WebSocketServer.unmask(buf, f['hlen'], f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'],
f['length']) f['length'])
else: else:
print("Unmasked frame: %s" % repr(buf)) logger.debug("Unmasked frame: %s" % repr(buf))
f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len]
if base64 and f['opcode'] in [1, 2]: if base64 and f['opcode'] in [1, 2]:
try: try:
f['payload'] = b64decode(f['payload']) f['payload'] = b64decode(f['payload'])
except: except:
print("Exception while b64decoding buffer: %s" % logger.exception("Exception while b64decoding buffer: %s" %
repr(buf)) (repr(buf)))
raise raise
if f['opcode'] == 0x08: if f['opcode'] == 0x08:
...@@ -375,27 +258,32 @@ Sec-WebSocket-Accept: %s\r ...@@ -375,27 +258,32 @@ Sec-WebSocket-Accept: %s\r
# #
# WebSocketServer logging/output functions # WebSocketRequestHandler logging/output functions
# #
def traffic(self, token="."): def print_traffic(self, token="."):
""" Show traffic flow in verbose mode. """ """ Show traffic flow mode. """
if self.verbose and not self.daemon: if self.traffic:
sys.stdout.write(token) sys.stdout.write(token)
sys.stdout.flush() sys.stdout.flush()
def msg(self, msg): def msg(self, msg, *args, **kwargs):
""" Output message with handler_id prefix. """ """ Output message with handler_id prefix. """
if not self.daemon: prefix = "% 3d: " % self.handler_id
print("% 3d: %s" % (self.handler_id, msg)) self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs)
def vmsg(self, msg): def vmsg(self, msg, *args, **kwargs):
""" Same as msg() but only if verbose. """ """ Same as msg() but as debug. """
if self.verbose: prefix = "% 3d: " % self.handler_id
self.msg(msg) self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs)
def warn(self, msg, *args, **kwargs):
""" Same as msg() but as warning. """
prefix = "% 3d: " % self.handler_id
self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs)
# #
# Main WebSocketServer methods # Main WebSocketRequestHandler methods
# #
def send_frames(self, bufs=None): def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already """ Encode and send WebSocket frames. Any frames already
...@@ -424,12 +312,12 @@ Sec-WebSocket-Accept: %s\r ...@@ -424,12 +312,12 @@ Sec-WebSocket-Accept: %s\r
while self.send_parts: while self.send_parts:
# Send pending frames # Send pending frames
buf = self.send_parts.pop(0) buf = self.send_parts.pop(0)
sent = self.client.send(buf) sent = self.request.send(buf)
if sent == len(buf): if sent == len(buf):
self.traffic("<") self.print_traffic("<")
else: else:
self.traffic("<.") self.print_traffic("<.")
self.send_parts.insert(0, buf[sent:]) self.send_parts.insert(0, buf[sent:])
break break
...@@ -446,7 +334,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -446,7 +334,7 @@ Sec-WebSocket-Accept: %s\r
bufs = [] bufs = []
tdelta = int(time.time()*1000) - self.start_time tdelta = int(time.time()*1000) - self.start_time
buf = self.client.recv(self.buffer_size) buf = self.request.recv(self.buffer_size)
if len(buf) == 0: if len(buf) == 0:
closed = {'code': 1000, 'reason': "Client closed abruptly"} closed = {'code': 1000, 'reason': "Client closed abruptly"}
return bufs, closed return bufs, closed
...@@ -457,12 +345,13 @@ Sec-WebSocket-Accept: %s\r ...@@ -457,12 +345,13 @@ Sec-WebSocket-Accept: %s\r
self.recv_part = None self.recv_part = None
while buf: while buf:
frame = self.decode_hybi(buf, base64=self.base64) frame = self.decode_hybi(buf, base64=self.base64,
#print("Received buf: %s, frame: %s" % (repr(buf), frame)) logger=self.logger)
#self.msg("Received buf: %s, frame: %s", repr(buf), frame)
if frame['payload'] == None: if frame['payload'] == None:
# Incomplete/partial frame # Incomplete/partial frame
self.traffic("}.") self.print_traffic("}.")
if frame['left'] > 0: if frame['left'] > 0:
self.recv_part = buf[-frame['left']:] self.recv_part = buf[-frame['left']:]
break break
...@@ -472,13 +361,13 @@ Sec-WebSocket-Accept: %s\r ...@@ -472,13 +361,13 @@ Sec-WebSocket-Accept: %s\r
'reason': frame['close_reason']} 'reason': frame['close_reason']}
break break
self.traffic("}") self.print_traffic("}")
if self.rec: if self.rec:
start = frame['hlen'] start = frame['hlen']
end = frame['hlen'] + frame['length'] end = frame['hlen'] + frame['length']
if frame['masked']: if frame['masked']:
recbuf = WebSocketServer.unmask(buf, frame['hlen'], recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'],
frame['length']) frame['length'])
else: else:
recbuf = buf[frame['hlen']:frame['hlen'] + recbuf = buf[frame['hlen']:frame['hlen'] +
...@@ -501,11 +390,10 @@ Sec-WebSocket-Accept: %s\r ...@@ -501,11 +390,10 @@ Sec-WebSocket-Accept: %s\r
msg = pack(">H%ds" % len(reason), code, reason) msg = pack(">H%ds" % len(reason), code, reason)
buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False)
self.client.send(buf) self.request.send(buf)
def do_websocket_handshake(self, headers, path): def do_websocket_handshake(self):
h = self.headers = headers h = self.headers
self.path = path
prot = 'WebSocket-Protocol' prot = 'WebSocket-Protocol'
protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
...@@ -520,7 +408,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -520,7 +408,8 @@ Sec-WebSocket-Accept: %s\r
if ver in ['7', '8', '13']: if ver in ['7', '8', '13']:
self.version = "hybi-%02d" % int(ver) self.version = "hybi-%02d" % int(ver)
else: else:
raise self.EClose('Unsupported protocol version %s' % ver) self.send_error(400, "Unsupported protocol version %s" % ver)
return False
key = h['Sec-WebSocket-Key'] key = h['Sec-WebSocket-Key']
...@@ -530,23 +419,320 @@ Sec-WebSocket-Accept: %s\r ...@@ -530,23 +419,320 @@ Sec-WebSocket-Accept: %s\r
elif 'base64' in protocols: elif 'base64' in protocols:
self.base64 = True self.base64 = True
else: else:
raise self.EClose("Client must support 'binary' or 'base64' protocol") self.send_error(400, "Client must support 'binary' or 'base64' protocol")
return False
# Generate the hash value for the accept header # Generate the hash value for the accept header
accept = b64encode(sha1(s2b(key + self.GUID)).digest()) accept = b64encode(sha1(s2b(key + self.GUID)).digest())
response = self.server_handshake_hybi % b2s(accept) self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket")
self.send_header("Connection", "Upgrade")
self.send_header("Sec-WebSocket-Accept", b2s(accept))
if self.base64: if self.base64:
response += "Sec-WebSocket-Protocol: base64\r\n" self.send_header("Sec-WebSocket-Protocol", "base64")
else:
self.send_header("Sec-WebSocket-Protocol", "binary")
self.end_headers()
return True
else:
self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.")
return False
def handle_websocket(self):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
if not self.do_websocket_handshake():
return False
# Indicate to server that a Websocket upgrade was done
self.server.ws_connection = True
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.start_time = int(time.time()*1000)
# client_address is empty with, say, UNIX domain sockets
client_addr = ""
is_ssl = False
try:
client_addr = self.client_address[0]
is_ssl = self.client_address[2]
except IndexError:
pass
if is_ssl:
self.stype = "SSL/TLS (wss://)"
else:
self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection", client_addr,
self.stype)
self.log_message("%s: Version %s, base64: '%s'", client_addr,
self.version, self.base64)
if self.path != '/':
self.log_message("%s: Path: '%s'", client_addr, self.path)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.log_message("opening record file: %s", fname)
self.rec = open(fname, 'w+')
encoding = "binary"
if self.base64: encoding = "base64"
self.rec.write("var VNC_frame_encoding = '%s';\n"
% encoding)
self.rec.write("var VNC_frame_data = [\n")
try:
self.new_websocket_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
self.send_close(exc.args[0], exc.args[1])
return True
else:
return False
def do_GET(self):
"""Handle GET request. Calls handle_websocket(). If unsuccessful,
and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called."""
if not self.handle_websocket():
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
else:
SimpleHTTPRequestHandler.do_GET(self)
def list_directory(self, path):
if self.file_only:
self.send_error(404, "No such file")
else:
return SimpleHTTPRequestHandler.list_directory(self, path)
def new_websocket_client(self):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")
def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
else:
SimpleHTTPRequestHandler.do_HEAD(self)
def finish(self):
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
def handle(self):
# When using run_once, we have a single process, so
# we cannot loop in BaseHTTPRequestHandler.handle; we
# must return and handle new connections
if self.run_once:
self.handle_one_request()
else:
SimpleHTTPRequestHandler.handle(self)
def log_request(self, code='-', size='-'):
if self.verbose:
SimpleHTTPRequestHandler.log_request(self, code, size)
class WebSocketServer(object):
"""
WebSockets server class.
As an alternative, the standard library SocketServer can be used
"""
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
log_prefix = "websocket"
# An exception before the WebSocket connection was established
class EClose(Exception):
pass
class Terminate(Exception):
pass
def __init__(self, RequestHandlerClass, listen_host='',
listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', ssl_only=None,
daemon=False, record='', web='',
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None):
# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
self.listen_host = listen_host
self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.traffic = traffic
self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl
# Make paths settings absolute
self.cert = os.path.abspath(cert)
self.key = self.web = self.record = ''
if key:
self.key = os.path.abspath(key)
if web:
self.web = os.path.abspath(web)
if record:
self.record = os.path.abspath(record)
if self.web:
os.chdir(self.web)
self.only_upgrade = not self.web
# Sanity checks
if not ssl and self.ssl_only:
raise Exception("No 'ssl' module and SSL-only specified")
if self.daemon and not resource:
raise Exception("Module 'resource' required to daemonize")
# Show configuration
self.msg("WebSocket server settings:")
self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port)
self.msg(" - Flash security policy server")
if self.web:
self.msg(" - Web server. Web root: %s", self.web)
if ssl:
if os.path.exists(self.cert):
self.msg(" - SSL/TLS support")
if self.ssl_only:
self.msg(" - Deny non-SSL/TLS connections")
else:
self.msg(" - No SSL/TLS support (no cert file)")
else: else:
response += "Sec-WebSocket-Protocol: binary\r\n" self.msg(" - No SSL/TLS support (no 'ssl' module)")
response += "\r\n" if self.daemon:
self.msg(" - Backgrounding (daemon)")
if self.record:
self.msg(" - Recording to '%s.*'", self.record)
#
# WebSocketServer static methods
#
@staticmethod
def get_logger():
return logging.getLogger("%s.%s" % (
WebSocketServer.log_prefix,
WebSocketServer.__class__.__name__))
@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, use_ssl=False, tcp_keepalive=True,
tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None):
""" Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
"""
flags = 0
if host == '':
host = None
if connect and not (port or unix_socket):
raise Exception("Connect mode requires a port")
if use_ssl and not ssl:
raise Exception("SSL socket requested but Python SSL module not loaded.");
if not connect and use_ssl:
raise Exception("SSL only supported in connect mode (for now)")
if not connect:
flags = flags | socket.AI_PASSIVE
if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags)
if not addrs:
raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0])
if prefer_ipv6:
addrs.reverse()
sock = socket.socket(addrs[0][0], addrs[0][1])
if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if tcp_keepcnt:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
tcp_keepcnt)
if tcp_keepidle:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
tcp_keepidle)
if tcp_keepintvl:
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
tcp_keepintvl)
if connect:
sock.connect(addrs[0][4])
if use_ssl:
sock = ssl.wrap_socket(sock)
else: else:
raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.") sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(addrs[0][4])
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)
return response return sock
@staticmethod
def daemonize(keepfd=None, chdir='/'):
os.umask(0)
if chdir:
os.chdir(chdir)
else:
os.chdir('/')
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
# Signal handling
signal.signal(signal.SIGTERM, signal.SIG_IGN)
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256
for fd in reversed(range(maxfd)):
try:
if fd != keepfd:
os.close(fd)
except OSError:
_, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise
# Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
def do_handshake(self, sock, address): def do_handshake(self, sock, address):
""" """
...@@ -565,7 +751,6 @@ Sec-WebSocket-Accept: %s\r ...@@ -565,7 +751,6 @@ Sec-WebSocket-Accept: %s\r
- Send a WebSockets handshake server response. - Send a WebSockets handshake server response.
- Return the socket for this WebSocket client. - Return the socket for this WebSocket client.
""" """
stype = ""
ready = select.select([sock], [], [], 3)[0] ready = select.select([sock], [], [], 3)[0]
...@@ -609,45 +794,38 @@ Sec-WebSocket-Accept: %s\r ...@@ -609,45 +794,38 @@ Sec-WebSocket-Accept: %s\r
else: else:
raise raise
self.scheme = "wss"
stype = "SSL/TLS (wss://)"
elif self.ssl_only: elif self.ssl_only:
raise self.EClose("non-SSL connection received but disallowed") raise self.EClose("non-SSL connection received but disallowed")
else: else:
retsock = sock retsock = sock
self.scheme = "ws"
stype = "Plain non-SSL (ws://)"
wsh = WSRequestHandler(retsock, address, not self.web) # If the address is like (host, port), we are extending it
if wsh.last_code == 101: # with a flag indicating SSL. Not many other options
# Continue on to handle WebSocket upgrade # available...
pass if len(address) == 2:
elif wsh.last_code == 405: address = (address[0], address[1], (retsock != sock))
raise self.EClose("Normal web request received but disallowed")
elif wsh.last_code < 200 or wsh.last_code >= 300:
raise self.EClose(wsh.last_message)
elif self.verbose:
raise self.EClose(wsh.last_message)
else:
raise self.EClose("")
response = self.do_websocket_handshake(wsh.headers, wsh.path) self.RequestHandlerClass(retsock, address, self)
self.msg("%s: %s WebSocket connection" % (address[0], stype)) # Return the WebSockets socket which may be SSL wrapped
self.msg("%s: Version %s, base64: '%s'" % (address[0], return retsock
self.version, self.base64))
if self.path != '/':
self.msg("%s: Path: '%s'" % (address[0], self.path))
#
# WebSocketServer logging/output functions
#
# Send server WebSockets handshake response def msg(self, *args, **kwargs):
#self.msg("sending response [%s]" % response) """ Output message as info """
retsock.send(s2b(response)) self.logger.log(logging.INFO, *args, **kwargs)
# Return the WebSockets socket which may be SSL wrapped def vmsg(self, *args, **kwargs):
return retsock """ Same as msg() but as debug. """
self.logger.log(logging.DEBUG, *args, **kwargs)
def warn(self, *args, **kwargs):
""" Same as msg() but as warning. """
self.logger.log(logging.WARN, *args, **kwargs)
# #
...@@ -662,6 +840,12 @@ Sec-WebSocket-Accept: %s\r ...@@ -662,6 +840,12 @@ Sec-WebSocket-Accept: %s\r
#self.vmsg("Running poll()") #self.vmsg("Running poll()")
pass pass
def terminate(self):
raise self.Terminate()
def multiprocessing_SIGCHLD(self, sig, stack):
self.vmsg('Reaing zombies, active child count is %s', len(multiprocessing.active_children()))
def fallback_SIGCHLD(self, sig, stack): def fallback_SIGCHLD(self, sig, stack):
# Reap zombies when using os.fork() (python 2.4) # Reap zombies when using os.fork() (python 2.4)
self.vmsg("Got SIGCHLD, reaping zombies") self.vmsg("Got SIGCHLD, reaping zombies")
...@@ -675,95 +859,83 @@ Sec-WebSocket-Accept: %s\r ...@@ -675,95 +859,83 @@ Sec-WebSocket-Accept: %s\r
def do_SIGINT(self, sig, stack): def do_SIGINT(self, sig, stack):
self.msg("Got SIGINT, exiting") self.msg("Got SIGINT, exiting")
sys.exit(0) self.terminate()
def do_SIGTERM(self, sig, stack):
self.msg("Got SIGTERM, exiting")
self.terminate()
def top_new_client(self, startsock, address): def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """ """ Do something with a WebSockets client connection. """
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.base64 = False
self.rec = None
self.start_time = int(time.time()*1000)
# handler process # handler process
client = None
try: try:
try: try:
self.client = self.do_handshake(startsock, address) client = self.do_handshake(startsock, address)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.msg("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
encoding = "binary"
if self.base64: encoding = "base64"
self.rec.write("var VNC_frame_encoding = '%s';\n"
% encoding)
self.rec.write("var VNC_frame_data = [\n")
self.ws_connection = True
self.new_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
if self.client:
self.send_close(exc.args[0], exc.args[1])
except self.EClose: except self.EClose:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
# Connection was not a WebSockets connection # Connection was not a WebSockets connection
if exc.args[0]: if exc.args[0]:
self.msg("%s: %s" % (address[0], exc.args[0])) self.msg("%s: %s" % (address[0], exc.args[0]))
except WebSocketServer.Terminate:
raise
except Exception: except Exception:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
self.msg("handler exception: %s" % str(exc)) self.msg("handler exception: %s" % str(exc))
if self.verbose: self.vmsg("exception", exc_info=True)
self.msg(traceback.format_exc())
finally: finally:
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
if self.client and self.client != startsock: if client and client != startsock:
# Close the SSL wrapped socket # Close the SSL wrapped socket
# Original socket closed by caller # Original socket closed by caller
self.client.close() client.close()
def new_client(self):
""" Do something with a WebSockets client connection. """
raise("WebSocketServer.new_client() must be overloaded")
def start_server(self): def start_server(self):
""" """
Daemonize if requested. Listen for for connections. Run Daemonize if requested. Listen for for connections. Run
do_handshake() method for each connection. If the connection do_handshake() method for each connection. If the connection
is a WebSockets client then call new_client() method (which must is a WebSockets client then call new_websocket_client() method (which must
be overridden) for each new client connection. be overridden) for each new client connection.
""" """
lsock = self.socket(self.listen_host, self.listen_port, False, self.prefer_ipv6) lsock = self.socket(self.listen_host, self.listen_port, False,
self.prefer_ipv6,
tcp_keepalive=self.tcp_keepalive,
tcp_keepcnt=self.tcp_keepcnt,
tcp_keepidle=self.tcp_keepidle,
tcp_keepintvl=self.tcp_keepintvl)
if self.daemon: if self.daemon:
self.daemonize(keepfd=lsock.fileno(), chdir=self.web) self.daemonize(keepfd=lsock.fileno(), chdir=self.web)
self.started() # Some things need to happen after daemonizing self.started() # Some things need to happen after daemonizing
# Allow override of SIGINT # Allow override of signals
original_signals = {
signal.SIGINT: signal.getsignal(signal.SIGINT),
signal.SIGTERM: signal.getsignal(signal.SIGTERM),
signal.SIGCHLD: signal.getsignal(signal.SIGCHLD),
}
signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGINT, self.do_SIGINT)
signal.signal(signal.SIGTERM, self.do_SIGTERM)
if not multiprocessing: if not multiprocessing:
# os.fork() (python 2.4) child reaper # os.fork() (python 2.4) child reaper
signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD) signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD)
else:
# make sure that _cleanup is called when children die
# by calling active_children on SIGCHLD
signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD)
last_active_time = self.launch_time last_active_time = self.launch_time
try:
while True: while True:
try: try:
try: try:
self.client = None
startsock = None startsock = None
pid = err = 0 pid = err = 0
child_count = 0 child_count = 0
if multiprocessing and self.idle_timeout: if multiprocessing:
# Collect zombie child processes
child_count = len(multiprocessing.active_children()) child_count = len(multiprocessing.active_children())
time_elapsed = time.time() - self.launch_time time_elapsed = time.time() - self.launch_time
...@@ -793,6 +965,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -793,6 +965,8 @@ Sec-WebSocket-Accept: %s\r
startsock, address = lsock.accept() startsock, address = lsock.accept()
else: else:
continue continue
except self.Terminate:
raise
except Exception: except Exception:
_, exc, _ = sys.exc_info() _, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'): if hasattr(exc, 'errno'):
...@@ -833,55 +1007,24 @@ Sec-WebSocket-Accept: %s\r ...@@ -833,55 +1007,24 @@ Sec-WebSocket-Accept: %s\r
# parent process # parent process
self.handler_id += 1 self.handler_id += 1
except KeyboardInterrupt: except (self.Terminate, SystemExit, KeyboardInterrupt):
_, exc, _ = sys.exc_info() self.msg("In exit")
print("In KeyboardInterrupt")
pass
except SystemExit:
_, exc, _ = sys.exc_info()
print("In SystemExit")
break break
except Exception: except Exception:
_, exc, _ = sys.exc_info() self.msg("handler exception: %s", str(exc))
self.msg("handler exception: %s" % str(exc)) self.vmsg("exception", exc_info=True)
if self.verbose:
self.msg(traceback.format_exc())
finally: finally:
if startsock: if startsock:
startsock.close() startsock.close()
finally:
# Close listen port # Close listen port
self.vmsg("Closing socket listening at %s:%s" self.vmsg("Closing socket listening at %s:%s",
% (self.listen_host, self.listen_port)) self.listen_host, self.listen_port)
lsock.close() lsock.close()
# Restore signals
for sig, func in original_signals.items():
signal.signal(sig, func)
# HTTP handler with WebSocket upgrade support
class WSRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, req, addr, only_upgrade=False):
self.only_upgrade = only_upgrade # only allow upgrades
SimpleHTTPRequestHandler.__init__(self, req, addr, object())
def do_GET(self):
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
# Just indicate that an WebSocket upgrade is needed
self.last_code = 101
self.last_message = "101 Switching Protocols"
elif self.only_upgrade:
# Normal web request responses are disabled
self.last_code = 405
self.last_message = "405 Method Not Allowed"
else:
SimpleHTTPRequestHandler.do_GET(self)
def send_response(self, code, message=None):
# Save the status code
self.last_code = code
SimpleHTTPRequestHandler.send_response(self, code, message)
def log_message(self, f, *args):
# Save instead of printing
self.last_message = f % args
...@@ -11,7 +11,11 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ...@@ -11,7 +11,11 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import signal, socket, optparse, time, os, sys, subprocess import signal, socket, optparse, time, os, sys, subprocess, logging
try: from socketserver import ForkingMixIn
except: from SocketServer import ForkingMixIn
try: from http.server import HTTPServer
except: from BaseHTTPServer import HTTPServer
from select import select from select import select
import websocket import websocket
try: try:
...@@ -20,15 +24,7 @@ except: ...@@ -20,15 +24,7 @@ except:
from cgi import parse_qs from cgi import parse_qs
from urlparse import urlparse from urlparse import urlparse
class WebSocketProxy(websocket.WebSocketServer): class ProxyRequestHandler(websocket.WebSocketRequestHandler):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target. All traffic to/from the client is base64
encoded/decoded to allow binary data to be sent/received to/from
the target.
"""
buffer_size = 65536
traffic_legend = """ traffic_legend = """
Traffic Legend: Traffic Legend:
...@@ -42,148 +38,33 @@ Traffic Legend: ...@@ -42,148 +38,33 @@ Traffic Legend:
<. - Client send partial <. - Client send partial
""" """
def __init__(self, *args, **kwargs): def new_websocket_client(self):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.target_cfg = kwargs.pop('target_cfg', None)
# Last 3 timestamps command was run
self.wrap_times = [0, 0, 0]
if self.wrap_cmd:
rebinder_path = ['./', os.path.dirname(sys.argv[0])]
self.rebinder = None
for rdir in rebinder_path:
rpath = os.path.join(rdir, "rebind.so")
if os.path.exists(rpath):
self.rebinder = rpath
break
if not self.rebinder:
raise Exception("rebind.so not found, perhaps you need to run make")
self.rebinder = os.path.abspath(self.rebinder)
self.target_host = "127.0.0.1" # Loopback
# Find a free high port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('', 0))
self.target_port = sock.getsockname()[1]
sock.close()
os.environ.update({
"LD_PRELOAD": self.rebinder,
"REBIND_OLD_PORT": str(kwargs['listen_port']),
"REBIND_NEW_PORT": str(self.target_port)})
if self.target_cfg:
self.target_cfg = os.path.abspath(self.target_cfg)
websocket.WebSocketServer.__init__(self, *args, **kwargs)
def run_wrap_cmd(self):
print("Starting '%s'" % " ".join(self.wrap_cmd))
self.wrap_times.append(time.time())
self.wrap_times.pop(0)
self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
self.spawn_message = True
def started(self):
"""
Called after Websockets server startup (i.e. after daemonize)
"""
# Need to call wrapped command after daemonization so we can
# know when the wrapped command exits
if self.wrap_cmd:
dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
elif self.unix_target:
dst_string = self.unix_target
else:
dst_string = "%s:%s" % (self.target_host, self.target_port)
if self.target_cfg:
msg = " - proxying from %s:%s to targets in %s" % (
self.listen_host, self.listen_port, self.target_cfg)
else:
msg = " - proxying from %s:%s to %s" % (
self.listen_host, self.listen_port, dst_string)
if self.ssl_target:
msg += " (using SSL)"
print(msg + "\n")
if self.wrap_cmd:
self.run_wrap_cmd()
def poll(self):
# If we are wrapping a command, check it's status
if self.wrap_cmd and self.cmd:
ret = self.cmd.poll()
if ret != None:
self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret)
self.cmd = None
if self.wrap_cmd and self.cmd == None:
# Response to wrapped command being gone
if self.wrap_mode == "ignore":
pass
elif self.wrap_mode == "exit":
sys.exit(ret)
elif self.wrap_mode == "respawn":
now = time.time()
avg = sum(self.wrap_times)/len(self.wrap_times)
if (now - avg) < 10:
# 3 times in the last 10 seconds
if self.spawn_message:
print("Command respawning too fast")
self.spawn_message = False
else:
self.run_wrap_cmd()
#
# Routines above this point are run in the master listener
# process.
#
#
# Routines below this point are connection handler routines and
# will be run in a separate forked process for each connection.
#
def new_client(self):
""" """
Called after a new WebSocket connection has been established. Called after a new WebSocket connection has been established.
""" """
# Checks if we receive a token, and look # Checks if we receive a token, and look
# for a valid target for it then # for a valid target for it then
if self.target_cfg: if self.server.target_cfg:
(self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path)
# Connect to the target # Connect to the target
if self.wrap_cmd: if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
elif self.unix_target: elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.unix_target msg = "connecting to unix socket: %s" % self.server.unix_target
else: else:
msg = "connecting to: %s:%s" % ( msg = "connecting to: %s:%s" % (
self.target_host, self.target_port) self.server.target_host, self.server.target_port)
if self.ssl_target: if self.server.ssl_target:
msg += " (using SSL)" msg += " (using SSL)"
self.msg(msg) self.log_message(msg)
tsock = self.socket(self.target_host, self.target_port, tsock = websocket.WebSocketServer.socket(self.server.target_host,
connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) self.server.target_port,
connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target)
if self.verbose and not self.daemon: self.print_traffic(self.traffic_legend)
print(self.traffic_legend)
# Start proxying # Start proxying
try: try:
...@@ -192,8 +73,9 @@ Traffic Legend: ...@@ -192,8 +73,9 @@ Traffic Legend:
if tsock: if tsock:
tsock.shutdown(socket.SHUT_RDWR) tsock.shutdown(socket.SHUT_RDWR)
tsock.close() tsock.close()
self.vmsg("%s:%s: Closed target" %( if self.verbose:
self.target_host, self.target_port)) self.log_message("%s:%s: Closed target",
self.server.target_host, self.server.target_port)
raise raise
def get_target(self, target_cfg, path): def get_target(self, target_cfg, path):
...@@ -242,31 +124,32 @@ Traffic Legend: ...@@ -242,31 +124,32 @@ Traffic Legend:
cqueue = [] cqueue = []
c_pend = 0 c_pend = 0
tqueue = [] tqueue = []
rlist = [self.client, target] rlist = [self.request, target]
while True: while True:
wlist = [] wlist = []
if tqueue: wlist.append(target) if tqueue: wlist.append(target)
if cqueue or c_pend: wlist.append(self.client) if cqueue or c_pend: wlist.append(self.request)
ins, outs, excepts = select(rlist, wlist, [], 1) ins, outs, excepts = select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception") if excepts: raise Exception("Socket exception")
if self.client in outs: if self.request in outs:
# Send queued target data to the client # Send queued target data to the client
c_pend = self.send_frames(cqueue) c_pend = self.send_frames(cqueue)
cqueue = [] cqueue = []
if self.client in ins: if self.request in ins:
# Receive client data, decode it, and queue for target # Receive client data, decode it, and queue for target
bufs, closed = self.recv_frames() bufs, closed = self.recv_frames()
tqueue.extend(bufs) tqueue.extend(bufs)
if closed: if closed:
# TODO: What about blocking on client socket? # TODO: What about blocking on client socket?
self.vmsg("%s:%s: Client closed connection" %( if self.verbose:
self.target_host, self.target_port)) self.log_message("%s:%s: Client closed connection",
self.server.target_host, self.server.target_port)
raise self.CClose(closed['code'], closed['reason']) raise self.CClose(closed['code'], closed['reason'])
...@@ -275,24 +158,139 @@ Traffic Legend: ...@@ -275,24 +158,139 @@ Traffic Legend:
dat = tqueue.pop(0) dat = tqueue.pop(0)
sent = target.send(dat) sent = target.send(dat)
if sent == len(dat): if sent == len(dat):
self.traffic(">") self.print_traffic(">")
else: else:
# requeue the remaining data # requeue the remaining data
tqueue.insert(0, dat[sent:]) tqueue.insert(0, dat[sent:])
self.traffic(".>") self.print_traffic(".>")
if target in ins: if target in ins:
# Receive target data, encode it and queue for client # Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size) buf = target.recv(self.buffer_size)
if len(buf) == 0: if len(buf) == 0:
self.vmsg("%s:%s: Target closed connection" %( if self.verbose:
self.target_host, self.target_port)) self.log_message("%s:%s: Target closed connection",
self.server.target_host, self.server.target_port)
raise self.CClose(1000, "Target closed") raise self.CClose(1000, "Target closed")
cqueue.append(buf) cqueue.append(buf)
self.traffic("{") self.print_traffic("{")
class WebSocketProxy(websocket.WebSocketServer):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target. All traffic to/from the client is base64
encoded/decoded to allow binary data to be sent/received to/from
the target.
"""
buffer_size = 65536
def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.target_cfg = kwargs.pop('target_cfg', None)
# Last 3 timestamps command was run
self.wrap_times = [0, 0, 0]
if self.wrap_cmd:
wsdir = os.path.dirname(sys.argv[0])
rebinder_path = [os.path.join(wsdir, "..", "lib"),
os.path.join(wsdir, "..", "lib", "websockify"),
wsdir]
self.rebinder = None
for rdir in rebinder_path:
rpath = os.path.join(rdir, "rebind.so")
if os.path.exists(rpath):
self.rebinder = rpath
break
if not self.rebinder:
raise Exception("rebind.so not found, perhaps you need to run make")
self.rebinder = os.path.abspath(self.rebinder)
self.target_host = "127.0.0.1" # Loopback
# Find a free high port
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(('', 0))
self.target_port = sock.getsockname()[1]
sock.close()
os.environ.update({
"LD_PRELOAD": self.rebinder,
"REBIND_OLD_PORT": str(kwargs['listen_port']),
"REBIND_NEW_PORT": str(self.target_port)})
websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs)
def run_wrap_cmd(self):
self.msg("Starting '%s'", " ".join(self.wrap_cmd))
self.wrap_times.append(time.time())
self.wrap_times.pop(0)
self.cmd = subprocess.Popen(
self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup)
self.spawn_message = True
def started(self):
"""
Called after Websockets server startup (i.e. after daemonize)
"""
# Need to call wrapped command after daemonization so we can
# know when the wrapped command exits
if self.wrap_cmd:
dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
elif self.unix_target:
dst_string = self.unix_target
else:
dst_string = "%s:%s" % (self.target_host, self.target_port)
if self.target_cfg:
msg = " - proxying from %s:%s to targets in %s" % (
self.listen_host, self.listen_port, self.target_cfg)
else:
msg = " - proxying from %s:%s to %s" % (
self.listen_host, self.listen_port, dst_string)
if self.ssl_target:
msg += " (using SSL)"
self.msg("%s", msg)
if self.wrap_cmd:
self.run_wrap_cmd()
def poll(self):
# If we are wrapping a command, check it's status
if self.wrap_cmd and self.cmd:
ret = self.cmd.poll()
if ret != None:
self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret)
self.cmd = None
if self.wrap_cmd and self.cmd == None:
# Response to wrapped command being gone
if self.wrap_mode == "ignore":
pass
elif self.wrap_mode == "exit":
sys.exit(ret)
elif self.wrap_mode == "respawn":
now = time.time()
avg = sum(self.wrap_times)/len(self.wrap_times)
if (now - avg) < 10:
# 3 times in the last 10 seconds
if self.spawn_message:
self.warn("Command respawning too fast")
self.spawn_message = False
else:
self.run_wrap_cmd()
def _subprocess_setup(): def _subprocess_setup():
...@@ -301,14 +299,28 @@ def _subprocess_setup(): ...@@ -301,14 +299,28 @@ def _subprocess_setup():
signal.signal(signal.SIGPIPE, signal.SIG_DFL) signal.signal(signal.SIGPIPE, signal.SIG_DFL)
def logger_init():
logger = logging.getLogger(WebSocketProxy.log_prefix)
logger.propagate = False
logger.setLevel(logging.INFO)
h = logging.StreamHandler()
h.setLevel(logging.DEBUG)
h.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(h)
def websockify_init(): def websockify_init():
logger_init()
usage = "\n %prog [options]" usage = "\n %prog [options]"
usage += " [source_addr:]source_port [target_addr:target_port]" usage += " [source_addr:]source_port [target_addr:target_port]"
usage += "\n %prog [options]" usage += "\n %prog [options]"
usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE" usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE"
parser = optparse.OptionParser(usage=usage) parser = optparse.OptionParser(usage=usage)
parser.add_option("--verbose", "-v", action="store_true", parser.add_option("--verbose", "-v", action="store_true",
help="verbose messages and per frame traffic") help="verbose messages")
parser.add_option("--traffic", action="store_true",
help="per frame traffic")
parser.add_option("--record", parser.add_option("--record",
help="record sessions to FILE.[session_number]", metavar="FILE") help="record sessions to FILE.[session_number]", metavar="FILE")
parser.add_option("--daemon", "-D", parser.add_option("--daemon", "-D",
...@@ -345,8 +357,13 @@ def websockify_init(): ...@@ -345,8 +357,13 @@ def websockify_init():
help="Configuration file containing valid targets " help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a " "in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form") "directory containing configuration files of this form")
parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine")
(opts, args) = parser.parse_args() (opts, args) = parser.parse_args()
if opts.verbose:
logging.getLogger(WebSocketProxy.log_prefix).setLevel(logging.DEBUG)
# Sanity checks # Sanity checks
if len(args) < 2 and not (opts.target_cfg or opts.unix_target): if len(args) < 2 and not (opts.target_cfg or opts.unix_target):
parser.error("Too few arguments") parser.error("Too few arguments")
...@@ -385,9 +402,70 @@ def websockify_init(): ...@@ -385,9 +402,70 @@ def websockify_init():
try: opts.target_port = int(opts.target_port) try: opts.target_port = int(opts.target_port)
except: parser.error("Error parsing target port") except: parser.error("Error parsing target port")
# Transform to absolute path as daemon may chdir
if opts.target_cfg:
opts.target_cfg = os.path.abspath(opts.target_cfg)
# Create and start the WebSockets proxy # Create and start the WebSockets proxy
libserver = opts.libserver
del opts.libserver
if libserver:
# Use standard Python SocketServer framework
server = LibProxyServer(**opts.__dict__)
server.serve_forever()
else:
# Use internal service framework
server = WebSocketProxy(**opts.__dict__) server = WebSocketProxy(**opts.__dict__)
server.start_server() server.start_server()
class LibProxyServer(ForkingMixIn, HTTPServer):
"""
Just like WebSocketProxy, but uses standard Python SocketServer
framework.
"""
def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.target_cfg = kwargs.pop('target_cfg', None)
self.daemon = False
self.target_cfg = None
# Server configuration
listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '')
# Configuration affecting base request handler
self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '')
if record:
self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False)
self.handler_id = 0
for arg in kwargs.keys():
print("warning: option %s ignored when using --libserver" % arg)
if web:
os.chdir(web)
HTTPServer.__init__(self, (listen_host, listen_port),
RequestHandlerClass)
def process_request(self, request, client_address):
"""Override process_request to implement a counter"""
self.handler_id += 1
ForkingMixIn.process_request(self, request, client_address)
if __name__ == '__main__': if __name__ == '__main__':
websockify_init() websockify_init()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment