Commit 8c305c60 authored by Joel Martin's avatar Joel Martin

Pull fix of recording from websockify.

Pull websockify 7f487fdbd.

The reocrd parameter will turn on recording of all messages sent
to and from the client. The record parameter is a file prefix. The
full file-name will be the prefix with an extension '.HANDLER_ID'
based on the handler ID.
parent fa8f14d5
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
''' '''
Python WebSocket library with support for "wss://" encryption. Python WebSocket library with support for "wss://" encryption.
Copyright 2010 Joel Martin Copyright 2011 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
Supports following protocol versions: Supports following protocol versions:
...@@ -16,23 +16,48 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ...@@ -16,23 +16,48 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import sys, socket, ssl, struct, traceback, select import os, sys, time, errno, signal, socket, struct, traceback, select
import os, resource, errno, signal # daemonizing from cgi import parse_qsl
from SimpleHTTPServer import SimpleHTTPRequestHandler
from cStringIO import StringIO
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
try:
# Imports that vary by python version
if sys.hexversion > 0x3000000:
# python >= 3.0
from io import StringIO
from http.server import SimpleHTTPRequestHandler
from urllib.parse import urlsplit
b2s = lambda buf: buf.decode('latin_1')
s2b = lambda s: s.encode('latin_1')
else:
# python 2.X
from cStringIO import StringIO
from SimpleHTTPServer import SimpleHTTPRequestHandler
from urlparse import urlsplit
# No-ops
b2s = lambda buf: buf
s2b = lambda s: s
if sys.hexversion >= 0x2060000:
# python >= 2.6
from multiprocessing import Process
from hashlib import md5, sha1 from hashlib import md5, sha1
except: else:
# Support python 2.4 # python < 2.6
Process = None
from md5 import md5 from md5 import md5
from sha import sha as sha1 from sha import sha as sha1
try:
import numpy, ctypes # Degraded functionality if these imports are missing
except: for mod, sup in [('numpy', 'HyBi protocol'),
numpy = ctypes = None ('ctypes', 'HyBi protocol'), ('ssl', 'TLS/SSL/wss'),
from urlparse import urlsplit ('resource', 'daemonizing')]:
from cgi import parse_qsl try:
globals()[mod] = __import__(mod)
except ImportError:
globals()[mod] = None
print("WARNING: no '%s' module, %s support disabled" % (
mod, sup))
class WebSocketServer(object): class WebSocketServer(object):
""" """
...@@ -72,6 +97,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -72,6 +97,7 @@ Sec-WebSocket-Accept: %s\r
self.listen_port = listen_port self.listen_port = listen_port
self.ssl_only = ssl_only self.ssl_only = ssl_only
self.daemon = daemon self.daemon = daemon
self.handler_id = 1
# Make paths settings absolute # Make paths settings absolute
self.cert = os.path.abspath(cert) self.cert = os.path.abspath(cert)
...@@ -86,22 +112,32 @@ Sec-WebSocket-Accept: %s\r ...@@ -86,22 +112,32 @@ Sec-WebSocket-Accept: %s\r
if self.web: if self.web:
os.chdir(self.web) os.chdir(self.web)
self.handler_id = 1 # Sanity checks
if ssl and self.ssl_only:
print "WebSocket server settings:" raise Exception("No 'ssl' module and SSL-only specified")
print " - Listen on %s:%s" % ( if self.daemon and not resource:
self.listen_host, self.listen_port) raise Exception("Module 'resource' required to daemonize")
print " - Flash security policy server"
# Show configuration
print("WebSocket server settings:")
print(" - Listen on %s:%s" % (
self.listen_host, self.listen_port))
print(" - Flash security policy server")
if self.web: if self.web:
print " - Web server" print(" - Web server")
if os.path.exists(self.cert): if ssl:
print " - SSL/TLS support" if os.path.exists(self.cert):
if self.ssl_only: print(" - SSL/TLS support")
print " - Deny non-SSL/TLS connections" if self.ssl_only:
print(" - Deny non-SSL/TLS connections")
else:
print(" - No SSL/TLS support (no cert file)")
else: else:
print " - No SSL/TLS support (no cert file)" print(" - No SSL/TLS support (no 'ssl' module)")
if self.daemon: if self.daemon:
print " - Backgrounding (daemon)" print(" - Backgrounding (daemon)")
if self.record:
print(" - Recording to '%s.*'" % self.record)
# #
# WebSocketServer static methods # WebSocketServer static methods
...@@ -133,7 +169,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -133,7 +169,8 @@ Sec-WebSocket-Accept: %s\r
try: try:
if fd != keepfd: if fd != keepfd:
os.close(fd) os.close(fd)
except OSError, exc: except OSError:
_, exc, _ = sys.exc_info()
if exc.errno != errno.EBADF: raise if exc.errno != errno.EBADF: raise
# Redirect I/O to /dev/null # Redirect I/O to /dev/null
...@@ -164,9 +201,9 @@ Sec-WebSocket-Accept: %s\r ...@@ -164,9 +201,9 @@ Sec-WebSocket-Accept: %s\r
elif payload_len >= 65536: elif payload_len >= 65536:
header = struct.pack('>BBQ', b1, 127, payload_len) header = struct.pack('>BBQ', b1, 127, payload_len)
#print "Encoded: %s" % repr(header + buf) #print("Encoded: %s" % repr(header + buf))
return header + buf return header + buf, len(header), 0
@staticmethod @staticmethod
def decode_hybi(buf, base64=False): def decode_hybi(buf, base64=False):
...@@ -175,6 +212,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -175,6 +212,7 @@ Sec-WebSocket-Accept: %s\r
{'fin' : 0_or_1, {'fin' : 0_or_1,
'opcode' : number, 'opcode' : number,
'mask' : 32_bit_number, 'mask' : 32_bit_number,
'hlen' : header_bytes_number,
'length' : payload_bytes_number, 'length' : payload_bytes_number,
'payload' : decoded_buffer, 'payload' : decoded_buffer,
'left' : bytes_left_number, 'left' : bytes_left_number,
...@@ -182,122 +220,103 @@ Sec-WebSocket-Accept: %s\r ...@@ -182,122 +220,103 @@ Sec-WebSocket-Accept: %s\r
'close_reason' : string} 'close_reason' : string}
""" """
ret = {'fin' : 0, f = {'fin' : 0,
'opcode' : 0, 'opcode' : 0,
'mask' : 0, 'mask' : 0,
'length' : 0, 'hlen' : 2,
'payload' : None, 'length' : 0,
'left' : 0, 'payload' : None,
'close_code' : None, 'left' : 0,
'close_reason' : None} 'close_code' : None,
'close_reason' : None}
blen = len(buf) blen = len(buf)
ret['left'] = blen f['left'] = blen
header_len = 2
if blen < header_len: if blen < f['hlen']:
return ret # Incomplete frame header return f # Incomplete frame header
b1, b2 = struct.unpack_from(">BB", buf) b1, b2 = struct.unpack_from(">BB", buf)
ret['opcode'] = b1 & 0x0f f['opcode'] = b1 & 0x0f
ret['fin'] = (b1 & 0x80) >> 7 f['fin'] = (b1 & 0x80) >> 7
has_mask = (b2 & 0x80) >> 7 has_mask = (b2 & 0x80) >> 7
ret['length'] = b2 & 0x7f f['length'] = b2 & 0x7f
if ret['length'] == 126: if f['length'] == 126:
header_len = 4 f['hlen'] = 4
if blen < header_len: if blen < f['hlen']:
return ret # Incomplete frame header return f # Incomplete frame header
(ret['length'],) = struct.unpack_from('>xxH', buf) (f['length'],) = struct.unpack_from('>xxH', buf)
elif ret['length'] == 127: elif f['length'] == 127:
header_len = 10 f['hlen'] = 10
if blen < header_len: if blen < f['hlen']:
return ret # Incomplete frame header return f # Incomplete frame header
(ret['length'],) = struct.unpack_from('>xxQ', buf) (f['length'],) = struct.unpack_from('>xxQ', buf)
full_len = header_len + has_mask * 4 + ret['length'] full_len = f['hlen'] + has_mask * 4 + f['length']
if blen < full_len: # Incomplete frame if blen < full_len: # Incomplete frame
return ret # Incomplete frame header return f # Incomplete frame header
# Number of bytes that are part of the next frame(s) # Number of bytes that are part of the next frame(s)
ret['left'] = blen - full_len f['left'] = blen - full_len
# Process 1 frame # Process 1 frame
if has_mask: if has_mask:
# unmask payload # unmask payload
ret['mask'] = buf[header_len:header_len+4] f['mask'] = buf[f['hlen']:f['hlen']+4]
b = c = '' b = c = ''
if ret['length'] >= 4: if f['length'] >= 4:
mask = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'), mask = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
offset=header_len, count=1) offset=f['hlen'], count=1)
data = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'), data = numpy.frombuffer(buf, dtype=numpy.dtype('<L4'),
offset=header_len + 4, count=int(ret['length'] / 4)) offset=f['hlen'] + 4, count=int(f['length'] / 4))
#b = numpy.bitwise_xor(data, mask).data #b = numpy.bitwise_xor(data, mask).data
b = numpy.bitwise_xor(data, mask).tostring() b = numpy.bitwise_xor(data, mask).tostring()
if ret['length'] % 4: if f['length'] % 4:
print "Partial unmask" print("Partial unmask")
mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'), mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
offset=header_len, count=(ret['length'] % 4)) offset=f['hlen'], count=(f['length'] % 4))
data = numpy.frombuffer(buf, dtype=numpy.dtype('B'), data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
offset=full_len - (ret['length'] % 4), offset=full_len - (f['length'] % 4),
count=(ret['length'] % 4)) count=(f['length'] % 4))
c = numpy.bitwise_xor(data, mask).tostring() c = numpy.bitwise_xor(data, mask).tostring()
ret['payload'] = b + c f['payload'] = b + c
else: else:
print "Unmasked frame:", repr(buf) print("Unmasked frame: %s" % repr(buf))
ret['payload'] = buf[(header_len + has_mask * 4):full_len] f['payload'] = buf[(f['hlen'] + has_mask * 4):full_len]
if base64 and ret['opcode'] in [1, 2]: if base64 and f['opcode'] in [1, 2]:
try: try:
ret['payload'] = b64decode(ret['payload']) f['payload'] = b64decode(f['payload'])
except: except:
print "Exception while b64decoding buffer:", repr(buf) print("Exception while b64decoding buffer: %s" %
repr(buf))
raise raise
if ret['opcode'] == 0x08: if f['opcode'] == 0x08:
if ret['length'] >= 2: if f['length'] >= 2:
ret['close_code'] = struct.unpack_from( f['close_code'] = struct.unpack_from(">H", f['payload'])
">H", ret['payload']) if f['length'] > 3:
if ret['length'] > 3: f['close_reason'] = f['payload'][2:]
ret['close_reason'] = ret['payload'][2:]
return ret return f
@staticmethod @staticmethod
def encode_hixie(buf): def encode_hixie(buf):
return "\x00" + b64encode(buf) + "\xff" return s2b("\x00" + b2s(b64encode(buf)) + "\xff"), 1, 1
@staticmethod @staticmethod
def decode_hixie(buf): def decode_hixie(buf):
end = buf.find('\xff') end = buf.find(s2b('\xff'))
return {'payload': b64decode(buf[1:end]), return {'payload': b64decode(buf[1:end]),
'hlen': 1,
'length': end - 1,
'left': len(buf) - (end + 1)} 'left': len(buf) - (end + 1)}
@staticmethod
def parse_handshake(handshake):
""" Parse fields from client WebSockets handshake. """
ret = {}
req_lines = handshake.split("\r\n")
if not req_lines[0].startswith("GET "):
raise Exception("Invalid handshake: no GET request line")
ret['path'] = req_lines[0].split(" ")[1]
for line in req_lines[1:]:
if line == "": break
try:
var, val = line.split(": ")
except:
raise Exception("Invalid handshake header: %s" % line)
ret[var] = val
if req_lines[-2] == "":
ret['key3'] = req_lines[-1]
return ret
@staticmethod @staticmethod
def gen_md5(keys): def gen_md5(keys):
""" Generate hash value for WebSockets hixie-76. """ """ Generate hash value for WebSockets hixie-76. """
...@@ -309,7 +328,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -309,7 +328,8 @@ Sec-WebSocket-Accept: %s\r
num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1 num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1
num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2 num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2
return md5(struct.pack('>II8s', num1, num2, key3)).digest() return b2s(md5(struct.pack('>II8s',
int(num1), int(num2), key3)).digest())
# #
# WebSocketServer logging/output functions # WebSocketServer logging/output functions
...@@ -324,7 +344,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -324,7 +344,7 @@ Sec-WebSocket-Accept: %s\r
def msg(self, msg): def msg(self, msg):
""" Output message with handler_id prefix. """ """ Output message with handler_id prefix. """
if not self.daemon: if not self.daemon:
print "% 3d: %s" % (self.handler_id, msg) print("% 3d: %s" % (self.handler_id, msg))
def vmsg(self, msg): def vmsg(self, msg):
""" Same as msg() but only if verbose. """ """ Same as msg() but only if verbose. """
...@@ -342,17 +362,27 @@ Sec-WebSocket-Accept: %s\r ...@@ -342,17 +362,27 @@ Sec-WebSocket-Accept: %s\r
than 0, then the caller should call again when the socket is than 0, then the caller should call again when the socket is
ready. """ ready. """
tdelta = int(time.time()*1000) - self.start_time
if bufs: if bufs:
for buf in bufs: for buf in bufs:
if self.version.startswith("hybi"): if self.version.startswith("hybi"):
if self.base64: if self.base64:
self.send_parts.append(self.encode_hybi(buf, encbuf, lenhead, lentail = self.encode_hybi(
opcode=1, base64=True)) buf, opcode=1, base64=True)
else: else:
self.send_parts.append(self.encode_hybi(buf, encbuf, lenhead, lentail = self.encode_hybi(
opcode=2, base64=False)) buf, opcode=2, base64=False)
else: else:
self.send_parts.append(self.encode_hixie(buf)) encbuf, lenhead, lentail = self.encode_hixie(buf)
if self.rec:
self.rec.write("%s,\n" %
repr("{%s{" % tdelta
+ encbuf[lenhead:-lentail]))
self.send_parts.append(encbuf)
while self.send_parts: while self.send_parts:
# Send pending frames # Send pending frames
...@@ -377,6 +407,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -377,6 +407,7 @@ Sec-WebSocket-Accept: %s\r
closed = False closed = False
bufs = [] bufs = []
tdelta = int(time.time()*1000) - self.start_time
buf = self.client.recv(self.buffer_size) buf = self.client.recv(self.buffer_size)
if len(buf) == 0: if len(buf) == 0:
...@@ -392,7 +423,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -392,7 +423,7 @@ Sec-WebSocket-Accept: %s\r
if self.version.startswith("hybi"): if self.version.startswith("hybi"):
frame = self.decode_hybi(buf, base64=self.base64) frame = self.decode_hybi(buf, base64=self.base64)
#print "Received buf: %s, frame: %s" % (repr(buf), frame) #print("Received buf: %s, frame: %s" % (repr(buf), frame))
if frame['payload'] == None: if frame['payload'] == None:
# Incomplete/partial frame # Incomplete/partial frame
...@@ -416,7 +447,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -416,7 +447,7 @@ Sec-WebSocket-Accept: %s\r
buf = buf[2:] buf = buf[2:]
continue # No-op continue # No-op
elif buf.count('\xff') == 0: elif buf.count(s2b('\xff')) == 0:
# Partial frame # Partial frame
self.traffic("}.") self.traffic("}.")
self.recv_part = buf self.recv_part = buf
...@@ -426,6 +457,13 @@ Sec-WebSocket-Accept: %s\r ...@@ -426,6 +457,13 @@ Sec-WebSocket-Accept: %s\r
self.traffic("}") self.traffic("}")
if self.rec:
start = frame['hlen']
end = frame['hlen'] + frame['length']
self.rec.write("%s,\n" %
repr("}%s}" % tdelta + buf[start:end]))
bufs.append(frame['payload']) bufs.append(frame['payload'])
if frame['left']: if frame['left']:
...@@ -439,7 +477,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -439,7 +477,7 @@ Sec-WebSocket-Accept: %s\r
""" Send a WebSocket orderly close frame. """ """ Send a WebSocket orderly close frame. """
if self.version.startswith("hybi"): if self.version.startswith("hybi"):
msg = '' msg = s2b('')
if code != None: if code != None:
msg = struct.pack(">H%ds" % (len(reason)), code) msg = struct.pack(">H%ds" % (len(reason)), code)
...@@ -447,7 +485,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -447,7 +485,7 @@ Sec-WebSocket-Accept: %s\r
self.client.send(buf) self.client.send(buf)
elif self.version == "hixie-76": elif self.version == "hixie-76":
buf = self.encode_hixie('\xff\x00') buf = s2b('\xff\x00')
self.client.send(buf) self.client.send(buf)
# No orderly close for 75 # No orderly close for 75
...@@ -483,14 +521,16 @@ Sec-WebSocket-Accept: %s\r ...@@ -483,14 +521,16 @@ Sec-WebSocket-Accept: %s\r
if handshake == "": if handshake == "":
raise self.EClose("ignoring empty handshake") raise self.EClose("ignoring empty handshake")
elif handshake.startswith("<policy-file-request/>"): elif handshake.startswith(s2b("<policy-file-request/>")):
# Answer Flash policy request # Answer Flash policy request
handshake = sock.recv(1024) handshake = sock.recv(1024)
sock.send(self.policy_response) sock.send(s2b(self.policy_response))
raise self.EClose("Sending flash policy response") raise self.EClose("Sending flash policy response")
elif handshake[0] in ("\x16", "\x80"): elif handshake[0] in ("\x16", "\x80"):
# SSL wrap the connection # SSL wrap the connection
if not ssl:
raise self.EClose("SSL connection but no 'ssl' module")
if not os.path.exists(self.cert): if not os.path.exists(self.cert):
raise self.EClose("SSL connection but '%s' not found" raise self.EClose("SSL connection but '%s' not found"
% self.cert) % self.cert)
...@@ -500,7 +540,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -500,7 +540,8 @@ Sec-WebSocket-Accept: %s\r
server_side=True, server_side=True,
certfile=self.cert, certfile=self.cert,
keyfile=self.key) keyfile=self.key)
except ssl.SSLError, x: except ssl.SSLError:
_, x, _ = sys.exc_info()
if x.args[0] == ssl.SSL_ERROR_EOF: if x.args[0] == ssl.SSL_ERROR_EOF:
raise self.EClose("") raise self.EClose("")
else: else:
...@@ -517,29 +558,21 @@ Sec-WebSocket-Accept: %s\r ...@@ -517,29 +558,21 @@ Sec-WebSocket-Accept: %s\r
scheme = "ws" scheme = "ws"
stype = "Plain non-SSL (ws://)" stype = "Plain non-SSL (ws://)"
# Now get the data from the socket wsh = WSRequestHandler(retsock, address, not self.web)
handshake = retsock.recv(4096) if wsh.last_code == 101:
# Continue on to handle WebSocket upgrade
if len(handshake) == 0: pass
raise self.EClose("Client closed during handshake") elif wsh.last_code == 405:
raise self.EClose("Normal web request received but disallowed")
# Check for and handle normal web requests elif wsh.last_code < 200 or wsh.last_code >= 300:
if (handshake.startswith('GET ') and raise self.EClose(wsh.last_message)
handshake.find('Upgrade: WebSocket\r\n') == -1 and elif self.verbose:
handshake.find('Upgrade: websocket\r\n') == -1): raise self.EClose(wsh.last_message)
if not self.web: else:
raise self.EClose("Normal web request received but disallowed") raise self.EClose("")
sh = SplitHTTPHandler(handshake, retsock, address)
if sh.last_code < 200 or sh.last_code >= 300:
raise self.EClose(sh.last_message)
elif self.verbose:
raise self.EClose(sh.last_message)
else:
raise self.EClose("")
#self.msg("handshake: " + repr(handshake)) h = self.headers = wsh.headers
# Parse client WebSockets handshake path = self.path = wsh.path
h = self.headers = self.parse_handshake(handshake)
prot = 'WebSocket-Protocol' prot = 'WebSocket-Protocol'
protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
...@@ -548,8 +581,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -548,8 +581,8 @@ Sec-WebSocket-Accept: %s\r
if ver: if ver:
# HyBi/IETF version of the protocol # HyBi/IETF version of the protocol
if not numpy or not ctypes: if sys.hexversion < 0x2060000 or not numpy:
self.EClose("Python numpy and ctypes modules required for HyBi-07 or greater") raise self.EClose("Python >= 2.6 and numpy module is required for HyBi-07 or greater")
if ver == '7': if ver == '7':
self.version = "hybi-07" self.version = "hybi-07"
...@@ -567,7 +600,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -567,7 +600,7 @@ Sec-WebSocket-Accept: %s\r
raise self.EClose("Client must support 'binary' or 'base64' protocol") raise self.EClose("Client must support 'binary' or 'base64' protocol")
# Generate the hash value for the accept header # Generate the hash value for the accept header
accept = b64encode(sha1(key + self.GUID).digest()) accept = b64encode(sha1(s2b(key + self.GUID)).digest())
response = self.server_handshake_hybi % accept response = self.server_handshake_hybi % accept
if self.base64: if self.base64:
...@@ -592,7 +625,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -592,7 +625,7 @@ Sec-WebSocket-Accept: %s\r
self.base64 = True self.base64 = True
response = self.server_handshake_hixie % (pre, response = self.server_handshake_hixie % (pre,
h['Origin'], pre, scheme, h['Host'], h['path']) h['Origin'], pre, scheme, h['Host'], path)
if 'base64' in protocols: if 'base64' in protocols:
response += "%sWebSocket-Protocol: base64\r\n" % pre response += "%sWebSocket-Protocol: base64\r\n" % pre
...@@ -606,7 +639,7 @@ Sec-WebSocket-Accept: %s\r ...@@ -606,7 +639,7 @@ Sec-WebSocket-Accept: %s\r
# Send server WebSockets handshake response # Send server WebSockets handshake response
#self.msg("sending response [%s]" % response) #self.msg("sending response [%s]" % response)
retsock.send(response) retsock.send(s2b(response))
# Return the WebSockets socket which may be SSL wrapped # Return the WebSockets socket which may be SSL wrapped
return retsock return retsock
...@@ -624,9 +657,8 @@ Sec-WebSocket-Accept: %s\r ...@@ -624,9 +657,8 @@ Sec-WebSocket-Accept: %s\r
#self.vmsg("Running poll()") #self.vmsg("Running poll()")
pass pass
def top_SIGCHLD(self, sig, stack): def fallback_SIGCHLD(self, sig, stack):
# Reap zombies after calling child SIGCHLD handler # Reap zombies when using os.fork() (python 2.4)
self.do_SIGCHLD(sig, stack)
self.vmsg("Got SIGCHLD, reaping zombies") self.vmsg("Got SIGCHLD, reaping zombies")
try: try:
result = os.waitpid(-1, os.WNOHANG) result = os.waitpid(-1, os.WNOHANG)
...@@ -636,14 +668,52 @@ Sec-WebSocket-Accept: %s\r ...@@ -636,14 +668,52 @@ Sec-WebSocket-Accept: %s\r
except (OSError): except (OSError):
pass pass
def do_SIGCHLD(self, sig, stack):
pass
def do_SIGINT(self, sig, stack): def do_SIGINT(self, sig, stack):
self.msg("Got SIGINT, exiting") self.msg("Got SIGINT, exiting")
sys.exit(0) sys.exit(0)
def new_client(self, client): def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.base64 = False
self.rec = None
self.start_time = int(time.time()*1000)
# handler process
try:
try:
self.client = self.do_handshake(startsock, address)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.msg("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
self.rec.write("var VNC_frame_data = [\n")
self.new_client()
except self.EClose:
_, exc, _ = sys.exc_info()
# Connection was not a WebSockets connection
if exc.args[0]:
self.msg("%s: %s" % (address[0], exc.args[0]))
except Exception:
_, exc, _ = sys.exc_info()
self.msg("handler exception: %s" % str(exc))
if self.verbose:
self.msg(traceback.format_exc())
finally:
if self.rec:
self.rec.write("'EOF']\n")
self.rec.close()
if self.client and self.client != startsock:
self.client.close()
def new_client(self):
""" Do something with a WebSockets client connection. """ """ Do something with a WebSockets client connection. """
raise("WebSocketServer.new_client() must be overloaded") raise("WebSocketServer.new_client() must be overloaded")
...@@ -665,9 +735,11 @@ Sec-WebSocket-Accept: %s\r ...@@ -665,9 +735,11 @@ Sec-WebSocket-Accept: %s\r
self.started() # Some things need to happen after daemonizing self.started() # Some things need to happen after daemonizing
# Reep zombies # Allow override of SIGINT
signal.signal(signal.SIGCHLD, self.top_SIGCHLD)
signal.signal(signal.SIGINT, self.do_SIGINT) signal.signal(signal.SIGINT, self.do_SIGINT)
if not Process:
# os.fork() (python 2.4) child reaper
signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD)
while True: while True:
try: try:
...@@ -679,14 +751,17 @@ Sec-WebSocket-Accept: %s\r ...@@ -679,14 +751,17 @@ Sec-WebSocket-Accept: %s\r
try: try:
self.poll() self.poll()
ready = select.select([lsock], [], [], 1)[0]; ready = select.select([lsock], [], [], 1)[0]
if lsock in ready: if lsock in ready:
startsock, address = lsock.accept() startsock, address = lsock.accept()
else: else:
continue continue
except Exception, exc: except Exception:
_, exc, _ = sys.exc_info()
if hasattr(exc, 'errno'): if hasattr(exc, 'errno'):
err = exc.errno err = exc.errno
elif hasattr(exc, 'args'):
err = exc.args[0]
else: else:
err = exc[0] err = exc[0]
if err == errno.EINTR: if err == errno.EINTR:
...@@ -695,55 +770,67 @@ Sec-WebSocket-Accept: %s\r ...@@ -695,55 +770,67 @@ Sec-WebSocket-Accept: %s\r
else: else:
raise raise
self.vmsg('%s: forking handler' % address[0]) if Process:
pid = os.fork() self.vmsg('%s: new handler Process' % address[0])
p = Process(target=self.top_new_client,
if pid == 0: args=(startsock, address))
# Initialize per client settings p.start()
self.send_parts = [] # child will not return
self.recv_part = None
self.base64 = False
# handler process
self.client = self.do_handshake(
startsock, address)
self.new_client()
else: else:
# parent process # python 2.4
self.handler_id += 1 self.vmsg('%s: forking handler' % address[0])
pid = os.fork()
except self.EClose, exc: if pid == 0:
# Connection was not a WebSockets connection # child handler process
if exc.args[0]: self.top_new_client(startsock, address)
self.msg("%s: %s" % (address[0], exc.args[0])) break # child process exits
except KeyboardInterrupt, exc:
# parent process
self.handler_id += 1
except KeyboardInterrupt:
_, exc, _ = sys.exc_info()
print("In KeyboardInterrupt")
pass pass
except Exception, exc: except SystemExit:
_, exc, _ = sys.exc_info()
print("In SystemExit")
break
except Exception:
_, exc, _ = sys.exc_info()
self.msg("handler exception: %s" % str(exc)) self.msg("handler exception: %s" % str(exc))
if self.verbose: if self.verbose:
self.msg(traceback.format_exc()) self.msg(traceback.format_exc())
finally: finally:
if self.client and self.client != startsock:
self.client.close()
if startsock: if startsock:
startsock.close() startsock.close()
if pid == 0:
break # Child process exits
# HTTP handler with request from a string and response to a socket # HTTP handler with WebSocket upgrade support
class SplitHTTPHandler(SimpleHTTPRequestHandler): class WSRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, req, resp, addr): def __init__(self, req, addr, only_upgrade=False):
# Save the response socket self.only_upgrade = only_upgrade # only allow upgrades
self.response = resp
SimpleHTTPRequestHandler.__init__(self, req, addr, object()) SimpleHTTPRequestHandler.__init__(self, req, addr, object())
def setup(self): def do_GET(self):
self.connection = self.response if (self.headers.get('upgrade') and
# Duck type request string to file object self.headers.get('upgrade').lower() == 'websocket'):
self.rfile = StringIO(self.request)
self.wfile = self.connection.makefile('wb', self.wbufsize) if (self.headers.get('sec-websocket-key1') or
self.headers.get('websocket-key1')):
# For Hixie-76 read out the key hash
self.headers.__setitem__('key3', self.rfile.read(8))
# Just indicate that an WebSocket upgrade is needed
self.last_code = 101
self.last_message = "101 Switching Protocols"
elif self.only_upgrade:
# Normal web request responses are disabled
self.last_code = 405
self.last_message = "405 Method Not Allowed"
else:
SimpleHTTPRequestHandler.do_GET(self)
def send_response(self, code, message=None): def send_response(self, code, message=None):
# Save the status code # Save the status code
...@@ -754,4 +841,3 @@ class SplitHTTPHandler(SimpleHTTPRequestHandler): ...@@ -754,4 +841,3 @@ class SplitHTTPHandler(SimpleHTTPRequestHandler):
# Save instead of printing # Save instead of printing
self.last_message = f % args self.last_message = f % args
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
''' '''
A WebSocket to TCP socket proxy with support for "wss://" encryption. A WebSocket to TCP socket proxy with support for "wss://" encryption.
Copyright 2010 Joel Martin Copyright 2011 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3) Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using: You can make a cert/key with openssl using:
...@@ -74,7 +74,7 @@ Traffic Legend: ...@@ -74,7 +74,7 @@ Traffic Legend:
WebSocketServer.__init__(self, *args, **kwargs) WebSocketServer.__init__(self, *args, **kwargs)
def run_wrap_cmd(self): def run_wrap_cmd(self):
print "Starting '%s'" % " ".join(self.wrap_cmd) print("Starting '%s'" % " ".join(self.wrap_cmd))
self.wrap_times.append(time.time()) self.wrap_times.append(time.time())
self.wrap_times.pop(0) self.wrap_times.pop(0)
self.cmd = subprocess.Popen( self.cmd = subprocess.Popen(
...@@ -88,14 +88,14 @@ Traffic Legend: ...@@ -88,14 +88,14 @@ Traffic Legend:
# Need to call wrapped command after daemonization so we can # Need to call wrapped command after daemonization so we can
# know when the wrapped command exits # know when the wrapped command exits
if self.wrap_cmd: if self.wrap_cmd:
print " - proxying from %s:%s to '%s' (port %s)\n" % ( print(" - proxying from %s:%s to '%s' (port %s)\n" % (
self.listen_host, self.listen_port, self.listen_host, self.listen_port,
" ".join(self.wrap_cmd), self.target_port) " ".join(self.wrap_cmd), self.target_port))
self.run_wrap_cmd() self.run_wrap_cmd()
else: else:
print " - proxying from %s:%s to %s:%s\n" % ( print(" - proxying from %s:%s to %s:%s\n" % (
self.listen_host, self.listen_port, self.listen_host, self.listen_port,
self.target_host, self.target_port) self.target_host, self.target_port))
def poll(self): def poll(self):
# If we are wrapping a command, check it's status # If we are wrapping a command, check it's status
...@@ -118,7 +118,7 @@ Traffic Legend: ...@@ -118,7 +118,7 @@ Traffic Legend:
if (now - avg) < 10: if (now - avg) < 10:
# 3 times in the last 10 seconds # 3 times in the last 10 seconds
if self.spawn_message: if self.spawn_message:
print "Command respawning too fast" print("Command respawning too fast")
self.spawn_message = False self.spawn_message = False
else: else:
self.run_wrap_cmd() self.run_wrap_cmd()
...@@ -138,15 +138,6 @@ Traffic Legend: ...@@ -138,15 +138,6 @@ Traffic Legend:
Called after a new WebSocket connection has been established. Called after a new WebSocket connection has been established.
""" """
self.rec = None
if self.record:
# Record raw frame data as a JavaScript compatible file
fname = "%s.%s" % (self.record,
self.handler_id)
self.msg("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
self.rec.write("var VNC_frame_data = [\n")
# Connect to the target # Connect to the target
self.msg("connecting to: %s:%s" % ( self.msg("connecting to: %s:%s" % (
self.target_host, self.target_port)) self.target_host, self.target_port))
...@@ -154,19 +145,17 @@ Traffic Legend: ...@@ -154,19 +145,17 @@ Traffic Legend:
tsock.connect((self.target_host, self.target_port)) tsock.connect((self.target_host, self.target_port))
if self.verbose and not self.daemon: if self.verbose and not self.daemon:
print self.traffic_legend print(self.traffic_legend)
# Start proxying # Start proxying
try: try:
self.do_proxy(tsock) self.do_proxy(tsock)
except: except:
if tsock: if tsock:
tsock.shutdown(socket.SHUT_RDWR)
tsock.close() tsock.close()
self.vmsg("%s:%s: Target closed" %( self.vmsg("%s:%s: Target closed" %(
self.target_host, self.target_port)) self.target_host, self.target_port))
if self.rec:
self.rec.write("'EOF']\n")
self.rec.close()
raise raise
def do_proxy(self, target): def do_proxy(self, target):
...@@ -177,11 +166,9 @@ Traffic Legend: ...@@ -177,11 +166,9 @@ Traffic Legend:
c_pend = 0 c_pend = 0
tqueue = [] tqueue = []
rlist = [self.client, target] rlist = [self.client, target]
tstart = int(time.time()*1000)
while True: while True:
wlist = [] wlist = []
tdelta = int(time.time()*1000) - tstart
if tqueue: wlist.append(target) if tqueue: wlist.append(target)
if cqueue or c_pend: wlist.append(self.client) if cqueue or c_pend: wlist.append(self.client)
...@@ -212,11 +199,8 @@ Traffic Legend: ...@@ -212,11 +199,8 @@ Traffic Legend:
if self.client in outs: if self.client in outs:
# Send queued target data to the client # Send queued target data to the client
c_pend = self.send_frames(cqueue) c_pend = self.send_frames(cqueue)
cqueue = []
#if self.rec: cqueue = []
# self.rec.write("%s,\n" %
# repr("{%s{" % tdelta + dat[1:-1]))
if self.client in ins: if self.client in ins:
...@@ -224,11 +208,6 @@ Traffic Legend: ...@@ -224,11 +208,6 @@ Traffic Legend:
bufs, closed = self.recv_frames() bufs, closed = self.recv_frames()
tqueue.extend(bufs) tqueue.extend(bufs)
#if self.rec:
# for b in bufs:
# self.rec.write(
# repr("}%s}%s" % (tdelta, b)) + ",\n")
if closed: if closed:
# TODO: What about blocking on client socket? # TODO: What about blocking on client socket?
self.send_close() self.send_close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment