Testing infrastructure for the web socket server

This commit is contained in:
Kovid Goyal 2015-10-25 16:06:46 +05:30
parent 4a6f04fd95
commit fb579a6257
2 changed files with 163 additions and 11 deletions

View File

@ -4,8 +4,134 @@
from __future__ import (unicode_literals, division, absolute_import, from __future__ import (unicode_literals, division, absolute_import,
print_function) print_function)
import socket, os, struct
from base64 import standard_b64encode
from collections import deque, namedtuple
from functools import partial
from hashlib import sha1
from calibre.srv.tests.base import BaseTest, TestServer from calibre.srv.tests.base import BaseTest, TestServer
from calibre.srv.web_socket import GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE
from calibre.utils.monotonic import monotonic
from calibre.utils.socket_inheritance import set_socket_inherit
HANDSHAKE_STR = '''\
GET / HTTP/1.1\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Key: {}\r
Sec-WebSocket-Version: 13\r
''' + '\r\n'
Frame = namedtuple('Frame', 'fin opcode payload')
class WSClient(object):
def __init__(self, port, timeout=5):
self.timeout = timeout
self.socket = socket.create_connection(('localhost', port), timeout)
set_socket_inherit(self.socket, False)
self.key = standard_b64encode(os.urandom(8))
self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii'))
self.read_buf = deque()
self.read_upgrade_response()
self.mask = os.urandom(4)
self.frames = []
def read_upgrade_response(self):
from calibre.srv.http_request import read_headers
st = monotonic()
buf, idx = b'', -1
while idx == -1:
data = self.socket.recv(1024)
if not data:
raise ValueError('Server did not respond with a valid HTTP upgrade response')
buf += data
if len(buf) > 4096:
raise ValueError('Server responded with too much data to HTTP upgrade request')
if monotonic() - st > self.timeout:
raise ValueError('Timed out while waiting for server response to HTTP upgrade')
idx = buf.find(b'\r\n\r\n')
response, rest = buf[:idx+4], buf[idx+4:]
if rest:
self.read_buf.append(rest)
lines = (x + b'\r\n' for x in response.split(b'\r\n')[:-1])
rl = next(lines)
if rl != b'HTTP/1.1 101 Switching Protocols\r\n':
raise ValueError('Server did not respond with correct switching protocols line')
headers = read_headers(partial(next, lines))
key = standard_b64encode(sha1(self.key + GUID_STR).digest())
if headers.get('Sec-WebSocket-Accept') != key:
raise ValueError('Server did not respond with correct key in Sec-WebSocket-Accept')
def recv(self, max_amt):
if self.read_buf:
data = self.read_buf.popleft()
if len(data) <= max_amt:
return data
self.read_buf.appendleft(data[max_amt+1:])
return data[:max_amt + 1]
return self.socket.recv(max_amt)
def read_size(self, size):
ans = b''
while len(ans) < size:
d = self.recv(size - len(ans))
if not d:
raise ValueError('Connection to server closed, no data received')
ans += d
return ans
def read_frame(self):
b1, b2 = bytearray(self.read_size(2))
fin = b1 & 0b10000000
opcode = b1 & 0b1111
masked = b2 & 0b10000000
if masked:
raise ValueError('Got a frame with mask bit set from the server')
payload_length = b2 & 0b01111111
if payload_length == 126:
payload_length = struct.unpack(b'!H', self.read_size(2))
elif payload_length == 127:
payload_length = struct.unpack(b'!Q', self.read_size(8))
return Frame(fin, opcode, self.read_size(payload_length))
def read_message(self):
frames = []
while True:
frame = self.read_frame()
frames.append(frame)
if frame.fin:
break
ans, opcode = [], None
for frame in frames:
if frame is frames[0]:
opcode = frame.opcode
if frame.fin == 0 and frame.opcode not in (BINARY, TEXT):
raise ValueError('Server sent a start frame with fin=0 and bad opcode')
ans.append(frame.payload)
ans = b''.join(ans)
if opcode == TEXT:
ans = ans.decode('utf-8')
return opcode, ans
def write_message(self, msg, chunk_size=None):
w = MessageWriter(msg, self.mask, chunk_size)
while True:
frame = w.create_frame()
if frame is None:
break
self.socket.sendall(frame.getvalue())
def write_frame(self, fin, opcode, payload=b'', rsv=0, mask=True):
frame = create_frame(fin, opcode, payload, rsv=rsv, mask=self.mask if mask else None)
self.socket.sendall(frame)
def write_close(self, code, reason=b''):
if isinstance(reason, type('')):
reason = reason.encode('utf-8')
self.write_frame(1, CLOSE, struct.pack(b'!H', code) + reason)
class TestHandler(object): class TestHandler(object):
@ -14,13 +140,13 @@ class TestHandler(object):
self.connection_state = {} self.connection_state = {}
def conn(self, cid): def conn(self, cid):
ans = self.ws_connections.get(cid) ans = self.connections.get(cid)
if ans is not None: if ans is not None:
ans = ans() ans = ans()
return ans return ans
def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders): def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders):
self.ws_connections[connection_id] = connection_ref self.connections[connection_id] = connection_ref
def handle_websocket_data(self, data, message_starting, message_finished, connection_id): def handle_websocket_data(self, data, message_starting, message_finished, connection_id):
pass pass
@ -42,7 +168,7 @@ class EchoHandler(TestHandler):
j = '' if isinstance(self.msg_buf[0], type('')) else b'' j = '' if isinstance(self.msg_buf[0], type('')) else b''
msg = j.join(self.msg_buf) msg = j.join(self.msg_buf)
self.msg_buf = [] self.msg_buf = []
self.conn(connection_id).send_websocket_message(msg) self.conn(connection_id).send_websocket_message(msg, wakeup=False)
class WSTestServer(TestServer): class WSTestServer(TestServer):
@ -56,10 +182,32 @@ class WSTestServer(TestServer):
def ws_handler(self): def ws_handler(self):
return self.loop.handler.websocket_handler return self.loop.handler.websocket_handler
def ws_connect(self):
return WSClient(self.address[1])
class WebSocketTest(BaseTest): class WebSocketTest(BaseTest):
def simple_test(self, client, msgs, expected, close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
for msg in msgs:
client.write_message(msg)
for ex in expected:
if isinstance(ex, type('')):
ex = TEXT, ex
elif isinstance(ex, bytes):
ex = BINARY, ex
elif isinstance(ex, int):
ex = ex, b''
self.ae(ex, client.read_message())
if send_close:
client.write_close(close_code, close_reason)
opcode, data = client.read_message()
self.ae(opcode, CLOSE)
self.ae(close_code, struct.unpack_from(b'!H', data, 0)[0])
def test_websocket_basic(self): def test_websocket_basic(self):
'Test basic interaction with the websocket server' 'Test basic interaction with the websocket server'
with WSTestServer(EchoHandler): with WSTestServer(EchoHandler) as server:
pass client = server.ws_connect()
st = partial(self.simple_test, client)
st([''], [''])

View File

@ -116,6 +116,9 @@ class ReadFrame(object): # {{{
self.state = self.read_payload self.state = self.read_payload
self.pos = 0 self.pos = 0
self.frame_starting = True self.frame_starting = True
if self.payload_length == 0:
conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin))
self.state = None
def read_payload(self, conn): def read_payload(self, conn):
bytes_left = self.payload_length - self.pos bytes_left = self.payload_length - self.pos
@ -133,13 +136,13 @@ class ReadFrame(object): # {{{
# Sending frames {{{ # Sending frames {{{
def create_frame(fin, opcode, payload, mask=None): def create_frame(fin, opcode, payload, mask=None, rsv=0):
if isinstance(payload, type('')): if isinstance(payload, type('')):
payload = payload.encode('utf-8') payload = payload.encode('utf-8')
l = len(payload) l = len(payload)
opcode &= 0b1111 opcode &= 0b1111
b1 = opcode | (0b10000000 if fin else 0) b1 = opcode | (0b10000000 if fin else 0)
b2 = 0 if mask is None else 0b10000000 b2 = rsv | (0 if mask is None else 0b10000000)
if l < 126: if l < 126:
header = bytes(bytearray((b1, b2 | l))) header = bytes(bytearray((b1, b2 | l)))
elif 126 <= l <= 65535: elif 126 <= l <= 65535:
@ -159,13 +162,14 @@ def create_frame(fin, opcode, payload, mask=None):
class MessageWriter(object): class MessageWriter(object):
def __init__(self, buf): def __init__(self, buf, mask=None, chunk_size=None):
self.buf, self.data_type = buf, BINARY self.buf, self.data_type, self.mask = buf, BINARY, mask
if isinstance(buf, type('')): if isinstance(buf, type('')):
self.buf, self.data_type = BytesIO(buf.encode('utf-8')), TEXT self.buf, self.data_type = BytesIO(buf.encode('utf-8')), TEXT
elif isinstance(buf, bytes): elif isinstance(buf, bytes):
self.buf = BytesIO(buf) self.buf = BytesIO(buf)
buf = self.buf buf = self.buf
self.chunk_size = chunk_size or SEND_CHUNK_SIZE
try: try:
pos = buf.tell() pos = buf.tell()
buf.seek(0, os.SEEK_END) buf.seek(0, os.SEEK_END)
@ -179,12 +183,12 @@ class MessageWriter(object):
if self.exhausted: if self.exhausted:
return None return None
buf = self.buf buf = self.buf
raw = buf.read(SEND_CHUNK_SIZE) raw = buf.read(self.chunk_size)
has_more = True if self.size is None else self.size > buf.tell() has_more = True if self.size is None else self.size > buf.tell()
fin = 0 if has_more and raw else 1 fin = 0 if has_more and raw else 1
opcode = 0 if self.first_frame_created else self.data_type opcode = 0 if self.first_frame_created else self.data_type
self.first_frame_created, self.exhausted = True, bool(fin) self.first_frame_created, self.exhausted = True, bool(fin)
return BytesIO(create_frame(fin, opcode, raw)) return BytesIO(create_frame(fin, opcode, raw, self.mask))
# }}} # }}}
conn_id = 0 conn_id = 0