Commit 6a883409 authored by Joel Martin's avatar Joel Martin

Refactor and cleanup websocket.py and deps.

Moved websocket.py code into a class WebSocketServer. WebSockets
server implementations will sub-class and define a handler() method
which is passed the client socket after. Global variable settings have been
changed to be parameters for WebSocketServer when created.

Subclass implementations still have to handle queueing and sending but
the parent class handles everything else (daemonizing, websocket
handshake, encode/decode, etc). It would be better if the parent class
could handle queueing and sending. This adds some buffering and
polling complexity to the parent class but it would be better to do so
at some point. However, the result is still much cleaner as can be
seen in wsecho.py.

Refactored wsproxy.py and wstest.py (formerly ws.py) to use the new
class. Added wsecho.py as a simple echo server.

- rename tests/ws.py to utils/wstest.py and add a symlink from
  tests/wstest.py

- rename tests/ws.html to tests/wstest.html to match utils/wstest.py.

- add utils/wsecho.py

- add tests/wsecho.html which communicates with wsecho.py and simply
  sends periodic messages and shows what is received.
parent 6ace64d3
#!/usr/bin/python
'''
WebSocket server-side load test program. Sends and receives traffic
that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted.
'''
import sys, os, socket, ssl, time, traceback
import random, time
from base64 import b64encode, b64decode
from select import select
sys.path.insert(0,os.path.dirname(__file__) + "/../utils/")
from websocket import *
buffer_size = 65536
max_packet_size = 10000
recv_cnt = send_cnt = 0
def check(buf):
global recv_cnt
try:
data_list = decode(buf)
except:
print "\n<BOF>" + repr(buf) + "<EOF>"
return "Failed to decode"
err = ""
for data in data_list:
if data.count('$') > 1:
raise Exception("Multiple parts within single packet")
if len(data) == 0:
traffic("_")
continue
if data[0] != "^":
err += "buf did not start with '^'\n"
continue
try:
cnt, length, chksum, nums = data[1:-1].split(':')
cnt = int(cnt)
length = int(length)
chksum = int(chksum)
except:
print "\n<BOF>" + repr(data) + "<EOF>"
err += "Invalid data format\n"
continue
if recv_cnt != cnt:
err += "Expected count %d but got %d\n" % (recv_cnt, cnt)
recv_cnt = cnt + 1
continue
recv_cnt += 1
if len(nums) != length:
err += "Expected length %d but got %d\n" % (length, len(nums))
continue
inv = nums.translate(None, "0123456789")
if inv:
err += "Invalid characters found: %s\n" % inv
continue
real_chksum = 0
for num in nums:
real_chksum += int(num)
if real_chksum != chksum:
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
return err
def generate():
global send_cnt, rand_array
length = random.randint(10, max_packet_size)
numlist = rand_array[max_packet_size-length:]
# Error in length
#numlist.append(5)
chksum = sum(numlist)
# Error in checksum
#numlist[0] = 5
nums = "".join( [str(n) for n in numlist] )
data = "^%d:%d:%d:%s$" % (send_cnt, length, chksum, nums)
send_cnt += 1
return encode(data)
def responder(client, delay=10):
global errors
cqueue = []
cpartial = ""
socks = [client]
last_send = time.time() * 1000
while True:
ins, outs, excepts = select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception")
if client in ins:
buf = client.recv(buffer_size)
if len(buf) == 0: raise Exception("Client closed")
#print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf))
if buf[-1] == '\xff':
if cpartial:
err = check(cpartial + buf)
cpartial = ""
else:
err = check(buf)
if err:
traffic("}")
errors = errors + 1
print err
else:
traffic(">")
else:
traffic(".>")
cpartial = cpartial + buf
now = time.time() * 1000
if client in outs and now > (last_send + delay):
last_send = now
#print "Client send: %s" % repr(cqueue[0])
client.send(generate())
traffic("<")
def test_handler(client):
global errors, delay, send_cnt, recv_cnt
send_cnt = 0
recv_cnt = 0
try:
responder(client, delay)
except:
print "accumulated errors:", errors
errors = 0
raise
if __name__ == '__main__':
errors = 0
try:
if len(sys.argv) < 2: raise
listen_port = int(sys.argv[1])
if len(sys.argv) == 3:
delay = int(sys.argv[2])
else:
delay = 10
except:
print "Usage: <listen_port> [delay_ms]"
sys.exit(1)
print "Prepopulating random array"
rand_array = []
for i in range(0, max_packet_size):
rand_array.append(random.randint(0, 9))
settings['listen_port'] = listen_port
settings['daemon'] = False
settings['handler'] = test_handler
start_server()
<html>
<head>
<title>WebSockets Echo Test</title>
<script src="include/base64.js"></script>
<script src="include/util.js"></script>
<script src="include/webutil.js"></script>
<!-- Uncomment to activate firebug lite -->
<!--
<script type='text/javascript'
src='http://getfirebug.com/releases/lite/1.2/firebug-lite-compressed.js'></script>
-->
</head>
<body>
Host: <input id='host' style='width:100'>&nbsp;
Port: <input id='port' style='width:50'>&nbsp;
Encrypt: <input id='encrypt' type='checkbox'>&nbsp;
<input id='connectButton' type='button' value='Start' style='width:100px'
onclick="connect();">&nbsp;
<br>
Log:<br>
<textarea id="messages" style="font-size: 9;" cols=80 rows=25></textarea>
</body>
<script>
var ws, host = null, port = null,
msg_cnt = 0, send_cnt = 1, echoDelay = 500,
echo_ref;
function message(str) {
console.log(str);
cell = $D('messages');
cell.innerHTML += msg_cnt + ": " + str + "\n";
cell.scrollTop = cell.scrollHeight;
msg_cnt++;
}
Array.prototype.pushStr = function (str) {
var n = str.length;
for (var i=0; i < n; i++) {
this.push(str.charCodeAt(i));
}
}
function send_msg() {
if (ws.bufferedAmount > 0) {
console.log("Delaying send");
return;
}
var str = "Message #" + send_cnt, arr = [];
arr.pushStr(str)
ws.send(Base64.encode(arr));
message("Sent message: '" + str + "'");
send_cnt++;
}
function update_stats() {
$D('sent').innerHTML = sent;
$D('received').innerHTML = received;
$D('errors').innerHTML = errors;
}
function init_ws() {
console.log(">> init_ws");
console.log("<< init_ws");
}
function connect() {
var host = $D('host').value,
port = $D('port').value,
scheme = "ws://", uri;
console.log(">> connect");
if ((!host) || (!port)) {
console.log("must set host and port");
return;
}
if (ws) {
ws.close();
}
if ($D('encrypt').checked) {
scheme = "wss://";
}
uri = scheme + host + ":" + port;
message("connecting to " + uri);
ws = new WebSocket(uri);
ws.onmessage = function(e) {
//console.log(">> WebSockets.onmessage");
var arr = Base64.decode(e.data), str = "", i;
for (i = 0; i < arr.length; i++) {
str = str + String.fromCharCode(arr[i]);
}
message("Received message '" + str + "'");
//console.log("<< WebSockets.onmessage");
};
ws.onopen = function(e) {
console.log(">> WebSockets.onopen");
echo_ref = setInterval(send_msg, echoDelay);
console.log("<< WebSockets.onopen");
};
ws.onclose = function(e) {
console.log(">> WebSockets.onclose");
if (echo_ref) {
clearInterval(echo_ref);
echo_ref = null;
}
console.log("<< WebSockets.onclose");
};
ws.onerror = function(e) {
console.log(">> WebSockets.onerror");
if (echo_ref) {
clearInterval(echo_ref);
echo_ref = null;
}
console.log("<< WebSockets.onerror");
};
$D('connectButton').value = "Stop";
$D('connectButton').onclick = disconnect;
console.log("<< connect");
}
function disconnect() {
console.log(">> disconnect");
if (ws) {
ws.close();
}
if (echo_ref) {
clearInterval(echo_ref);
}
$D('connectButton').value = "Start";
$D('connectButton').onclick = connect;
console.log("<< disconnect");
}
/* If no builtin websockets then load web_socket.js */
if (window.WebSocket) {
VNC_native_ws = true;
} else {
VNC_native_ws = false;
console.log("Loading web-socket-js flash bridge");
var extra = "<script src='include/web-socket-js/swfobject.js'><\/script>";
extra += "<script src='include/web-socket-js/FABridge.js'><\/script>";
extra += "<script src='include/web-socket-js/web_socket.js'><\/script>";
document.write(extra);
}
window.onload = function() {
console.log("onload");
if (!VNC_native_ws) {
console.log("initializing web-socket-js flash bridge");
WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf";
WebSocket.__initialize();
}
var url = document.location.href;
$D('host').value = (url.match(/host=([^&#]*)/) || ['',''])[1];
$D('port').value = (url.match(/port=([^&#]*)/) || ['',''])[1];
}
</script>
</html>
../utils/wstest.py
\ No newline at end of file
This diff is collapsed.
...@@ -11,14 +11,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ...@@ -11,14 +11,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
''' '''
import socket, optparse, time import socket, optparse, time, os
from select import select from select import select
from websocket import * from websocket import WebSocketServer
buffer_size = 65536 class WebSocketProxy(WebSocketServer):
rec = None """
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target. All traffic to/from the client is base64
encoded/decoded to allow binary data to be sent/received to/from
the target.
"""
traffic_legend = """ buffer_size = 65536
traffic_legend = """
Traffic Legend: Traffic Legend:
} - Client receive } - Client receive
}. - Client receive partial }. - Client receive partial
...@@ -30,9 +37,49 @@ Traffic Legend: ...@@ -30,9 +37,49 @@ Traffic Legend:
<. - Client send partial <. - Client send partial
""" """
def do_proxy(client, target): def __init__(self, *args, **kwargs):
""" Proxy WebSocket to normal socket. """ # Save off the target host:port
global rec self.target_host = kwargs.pop('target_host')
self.target_port = kwargs.pop('target_port')
WebSocketServer.__init__(self, *args, **kwargs)
def handler(self, client):
"""
Called after a new WebSocket connection has been established.
"""
self.rec = None
if self.record:
# Record raw frame data as a JavaScript compatible file
fname = "%s.%s" % (self.record,
self.handler_id)
self.msg("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
self.rec.write("var VNC_frame_data = [\n")
# Connect to the target
self.msg("connecting to: %s:%s" % (
self.target_host, self.target_port))
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tsock.connect((self.target_host, self.target_port))
if self.verbose and not self.daemon:
print self.traffic_legend
# Stat proxying
try:
self.do_proxy(client, tsock)
except:
if tsock: tsock.close()
if self.rec:
self.rec.write("'EOF']\n")
self.rec.close()
raise
def do_proxy(self, client, target):
"""
Proxy client WebSocket to normal target socket.
"""
cqueue = [] cqueue = []
cpartial = "" cpartial = ""
tqueue = [] tqueue = []
...@@ -42,90 +89,71 @@ def do_proxy(client, target): ...@@ -42,90 +89,71 @@ def do_proxy(client, target):
while True: while True:
wlist = [] wlist = []
tdelta = int(time.time()*1000) - tstart tdelta = int(time.time()*1000) - tstart
if tqueue: wlist.append(target) if tqueue: wlist.append(target)
if cqueue: wlist.append(client) if cqueue: wlist.append(client)
ins, outs, excepts = select(rlist, wlist, [], 1) ins, outs, excepts = select(rlist, wlist, [], 1)
if excepts: raise Exception("Socket exception") if excepts: raise Exception("Socket exception")
if target in outs: if target in outs:
# Send queued client data to the target
dat = tqueue.pop(0) dat = tqueue.pop(0)
sent = target.send(dat) sent = target.send(dat)
if sent == len(dat): if sent == len(dat):
traffic(">") self.traffic(">")
else: else:
# requeue the remaining data
tqueue.insert(0, dat[sent:]) tqueue.insert(0, dat[sent:])
traffic(".>") self.traffic(".>")
##if rec: rec.write("Target send: %s\n" % map(ord, dat))
if client in outs: if client in outs:
# Send queued target data to the client
dat = cqueue.pop(0) dat = cqueue.pop(0)
sent = client.send(dat) sent = client.send(dat)
if sent == len(dat): if sent == len(dat):
traffic("<") self.traffic("<")
##if rec: rec.write("Client send: %s ...\n" % repr(dat[0:80])) if self.rec:
if rec: rec.write("%s,\n" % repr("{%s{" % tdelta + dat[1:-1])) self.rec.write("%s,\n" %
repr("{%s{" % tdelta + dat[1:-1]))
else: else:
cqueue.insert(0, dat[sent:]) cqueue.insert(0, dat[sent:])
traffic("<.") self.traffic("<.")
##if rec: rec.write("Client send partial: %s\n" % repr(dat[0:send]))
if target in ins: if target in ins:
buf = target.recv(buffer_size) # Receive target data, encode it and queue for client
if len(buf) == 0: raise EClose("Target closed") buf = target.recv(self.buffer_size)
if len(buf) == 0: raise self.EClose("Target closed")
cqueue.append(encode(buf)) cqueue.append(self.encode(buf))
traffic("{") self.traffic("{")
##if rec: rec.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf)))
if client in ins: if client in ins:
buf = client.recv(buffer_size) # Receive client data, decode it, and queue for target
if len(buf) == 0: raise EClose("Client closed") buf = client.recv(self.buffer_size)
if len(buf) == 0: raise self.EClose("Client closed")
if buf == '\xff\x00': if buf == '\xff\x00':
raise EClose("Client sent orderly close frame") raise self.EClose("Client sent orderly close frame")
elif buf[-1] == '\xff': elif buf[-1] == '\xff':
if buf.count('\xff') > 1: if buf.count('\xff') > 1:
traffic(str(buf.count('\xff'))) self.traffic(str(buf.count('\xff')))
traffic("}") self.traffic("}")
##if rec: rec.write("Client recv (%d): %s\n" % (len(buf), repr(buf))) if self.rec:
if rec: rec.write("%s,\n" % (repr("}%s}" % tdelta + buf[1:-1]))) self.rec.write("%s,\n" %
(repr("}%s}" % tdelta + buf[1:-1])))
if cpartial: if cpartial:
tqueue.extend(decode(cpartial + buf)) # Prepend saved partial and decode frame(s)
tqueue.extend(self.decode(cpartial + buf))
cpartial = "" cpartial = ""
else: else:
tqueue.extend(decode(buf)) # decode frame(s)
tqueue.extend(self.decode(buf))
else: else:
traffic(".}") # Save off partial WebSockets frame
##if rec: rec.write("Client recv partial (%d): %s\n" % (len(buf), repr(buf))) self.traffic(".}")
cpartial = cpartial + buf cpartial = cpartial + buf
def proxy_handler(client):
global target_host, target_port, options, rec, fname
if settings['record']:
fname = "%s.%s" % (settings['record'],
settings['handler_id'])
handler_msg("opening record file: %s" % fname)
rec = open(fname, 'w+')
rec.write("var VNC_frame_data = [\n")
handler_msg("connecting to: %s:%s" % (target_host, target_port))
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tsock.connect((target_host, target_port))
if settings['verbose'] and not settings['daemon']:
print traffic_legend
try:
do_proxy(client, tsock)
except:
if tsock: tsock.close()
if rec:
rec.write("'EOF']\n")
rec.close()
raise
if __name__ == '__main__': if __name__ == '__main__':
usage = "%prog [--record FILE]" usage = "%prog [--record FILE]"
usage += " [source_addr:]source_port target_addr:target_port" usage += " [source_addr:]source_port target_addr:target_port"
...@@ -145,40 +173,31 @@ if __name__ == '__main__': ...@@ -145,40 +173,31 @@ if __name__ == '__main__':
help="disallow non-encrypted connections") help="disallow non-encrypted connections")
parser.add_option("--web", default=None, metavar="DIR", parser.add_option("--web", default=None, metavar="DIR",
help="run webserver on same port. Serve files from DIR.") help="run webserver on same port. Serve files from DIR.")
(options, args) = parser.parse_args() (opts, args) = parser.parse_args()
# Sanity checks
if len(args) > 2: parser.error("Too many arguments") if len(args) > 2: parser.error("Too many arguments")
if len(args) < 2: parser.error("Too few arguments") if len(args) < 2: parser.error("Too few arguments")
if opts.ssl_only and not os.path.exists(opts.cert):
parser.error("SSL only and %s not found" % opts.cert)
elif not os.path.exists(opts.cert):
print "Warning: %s not found" % opts.cert
# Parse host:port and convert ports to numbers
if args[0].count(':') > 0: if args[0].count(':') > 0:
host,port = args[0].split(':') opts.listen_host, opts.listen_port = args[0].split(':')
else: else:
host,port = '',args[0] opts.listen_host, opts.listen_port = '', args[0]
if args[1].count(':') > 0: if args[1].count(':') > 0:
target_host,target_port = args[1].split(':') opts.target_host, opts.target_port = args[1].split(':')
else: else:
parser.error("Error parsing target") parser.error("Error parsing target")
try: port = int(port) try: opts.listen_port = int(opts.listen_port)
except: parser.error("Error parsing listen port") except: parser.error("Error parsing listen port")
try: target_port = int(target_port) try: opts.target_port = int(opts.target_port)
except: parser.error("Error parsing target port") except: parser.error("Error parsing target port")
if options.ssl_only and not os.path.exists(options.cert): # Create and start the WebSockets proxy
parser.error("SSL only and %s not found" % options.cert) server = WebSocketProxy(**opts.__dict__)
elif not os.path.exists(options.cert): server.start_server()
print "Warning: %s not found" % options.cert
settings['verbose'] = options.verbose
settings['listen_host'] = host
settings['listen_port'] = port
settings['handler'] = proxy_handler
settings['cert'] = os.path.abspath(options.cert)
if options.key:
settings['key'] = os.path.abspath(options.key)
settings['ssl_only'] = options.ssl_only
settings['daemon'] = options.daemon
if options.record:
settings['record'] = os.path.abspath(options.record)
if options.web:
os.chdir = options.web
settings['web'] = options.web
start_server()
#!/usr/bin/python
'''
WebSocket server-side load test program. Sends and receives traffic
that has a random payload (length and content) that is checksummed and
given a sequence number. Any errors are reported and counted.
'''
import sys, os, socket, ssl, time, traceback
import random, time
from select import select
sys.path.insert(0,os.path.dirname(__file__) + "/../utils/")
from websocket import WebSocketServer
class WebSocketTest(WebSocketServer):
buffer_size = 65536
max_packet_size = 10000
recv_cnt = 0
send_cnt = 0
def __init__(self, *args, **kwargs):
self.errors = 0
self.delay = kwargs.pop('delay')
print "Prepopulating random array"
self.rand_array = []
for i in range(0, self.max_packet_size):
self.rand_array.append(random.randint(0, 9))
WebSocketServer.__init__(self, *args, **kwargs)
def handler(self, client):
self.send_cnt = 0
self.recv_cnt = 0
try:
self.responder(client)
except:
print "accumulated errors:", self.errors
self.errors = 0
raise
def responder(self, client):
cqueue = []
cpartial = ""
socks = [client]
last_send = time.time() * 1000
while True:
ins, outs, excepts = select(socks, socks, socks, 1)
if excepts: raise Exception("Socket exception")
if client in ins:
buf = client.recv(self.buffer_size)
if len(buf) == 0:
raise self.EClose("Client closed")
#print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf))
if buf[-1] == '\xff':
if cpartial:
err = self.check(cpartial + buf)
cpartial = ""
else:
err = self.check(buf)
if err:
self.traffic("}")
self.errors = self.errors + 1
print err
else:
self.traffic(">")
else:
self.traffic(".>")
cpartial = cpartial + buf
now = time.time() * 1000
if client in outs and now > (last_send + self.delay):
last_send = now
#print "Client send: %s" % repr(cqueue[0])
client.send(self.generate())
self.traffic("<")
def generate(self):
length = random.randint(10, self.max_packet_size)
numlist = self.rand_array[self.max_packet_size-length:]
# Error in length
#numlist.append(5)
chksum = sum(numlist)
# Error in checksum
#numlist[0] = 5
nums = "".join( [str(n) for n in numlist] )
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
self.send_cnt += 1
return WebSocketServer.encode(data)
def check(self, buf):
try:
data_list = WebSocketServer.decode(buf)
except:
print "\n<BOF>" + repr(buf) + "<EOF>"
return "Failed to decode"
err = ""
for data in data_list:
if data.count('$') > 1:
raise Exception("Multiple parts within single packet")
if len(data) == 0:
self.traffic("_")
continue
if data[0] != "^":
err += "buf did not start with '^'\n"
continue
try:
cnt, length, chksum, nums = data[1:-1].split(':')
cnt = int(cnt)
length = int(length)
chksum = int(chksum)
except:
print "\n<BOF>" + repr(data) + "<EOF>"
err += "Invalid data format\n"
continue
if self.recv_cnt != cnt:
err += "Expected count %d but got %d\n" % (self.recv_cnt, cnt)
self.recv_cnt = cnt + 1
continue
self.recv_cnt += 1
if len(nums) != length:
err += "Expected length %d but got %d\n" % (length, len(nums))
continue
inv = nums.translate(None, "0123456789")
if inv:
err += "Invalid characters found: %s\n" % inv
continue
real_chksum = 0
for num in nums:
real_chksum += int(num)
if real_chksum != chksum:
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
return err
if __name__ == '__main__':
try:
if len(sys.argv) < 2: raise
listen_port = int(sys.argv[1])
if len(sys.argv) == 3:
delay = int(sys.argv[2])
else:
delay = 10
except:
print "Usage: %s <listen_port> [delay_ms]" % sys.argv[0]
sys.exit(1)
server = WebSocketTest(
listen_port=listen_port,
verbose=True,
cert='self.pem',
web='.',
delay=delay)
server.start_server()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment