From fb579a6257a682cad22799c5d9462312f86d87f4 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 25 Oct 2015 16:06:46 +0530 Subject: [PATCH] Testing infrastructure for the web socket server --- src/calibre/srv/tests/web_sockets.py | 158 ++++++++++++++++++++++++++- src/calibre/srv/web_socket.py | 16 ++- 2 files changed, 163 insertions(+), 11 deletions(-) diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index 15c8f78463..c4fc4dca1d 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -4,8 +4,134 @@ from __future__ import (unicode_literals, division, absolute_import, print_function) +import socket, os, struct +from base64 import standard_b64encode +from collections import deque, namedtuple +from functools import partial +from hashlib import sha1 from calibre.srv.tests.base import BaseTest, TestServer +from calibre.srv.web_socket import GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE +from calibre.utils.monotonic import monotonic +from calibre.utils.socket_inheritance import set_socket_inherit + +HANDSHAKE_STR = '''\ +GET / HTTP/1.1\r +Upgrade: websocket\r +Connection: Upgrade\r +Sec-WebSocket-Key: {}\r +Sec-WebSocket-Version: 13\r +''' + '\r\n' + +Frame = namedtuple('Frame', 'fin opcode payload') + +class WSClient(object): + + def __init__(self, port, timeout=5): + self.timeout = timeout + self.socket = socket.create_connection(('localhost', port), timeout) + set_socket_inherit(self.socket, False) + self.key = standard_b64encode(os.urandom(8)) + self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii')) + self.read_buf = deque() + self.read_upgrade_response() + self.mask = os.urandom(4) + self.frames = [] + + def read_upgrade_response(self): + from calibre.srv.http_request import read_headers + st = monotonic() + buf, idx = b'', -1 + while idx == -1: + data = self.socket.recv(1024) + if not data: + raise ValueError('Server did not respond with a valid HTTP upgrade response') + buf += data + if len(buf) > 4096: + raise ValueError('Server responded with too much data to HTTP upgrade request') + if monotonic() - st > self.timeout: + raise ValueError('Timed out while waiting for server response to HTTP upgrade') + idx = buf.find(b'\r\n\r\n') + response, rest = buf[:idx+4], buf[idx+4:] + if rest: + self.read_buf.append(rest) + lines = (x + b'\r\n' for x in response.split(b'\r\n')[:-1]) + rl = next(lines) + if rl != b'HTTP/1.1 101 Switching Protocols\r\n': + raise ValueError('Server did not respond with correct switching protocols line') + headers = read_headers(partial(next, lines)) + key = standard_b64encode(sha1(self.key + GUID_STR).digest()) + if headers.get('Sec-WebSocket-Accept') != key: + raise ValueError('Server did not respond with correct key in Sec-WebSocket-Accept') + + def recv(self, max_amt): + if self.read_buf: + data = self.read_buf.popleft() + if len(data) <= max_amt: + return data + self.read_buf.appendleft(data[max_amt+1:]) + return data[:max_amt + 1] + return self.socket.recv(max_amt) + + def read_size(self, size): + ans = b'' + while len(ans) < size: + d = self.recv(size - len(ans)) + if not d: + raise ValueError('Connection to server closed, no data received') + ans += d + return ans + + def read_frame(self): + b1, b2 = bytearray(self.read_size(2)) + fin = b1 & 0b10000000 + opcode = b1 & 0b1111 + masked = b2 & 0b10000000 + if masked: + raise ValueError('Got a frame with mask bit set from the server') + payload_length = b2 & 0b01111111 + if payload_length == 126: + payload_length = struct.unpack(b'!H', self.read_size(2)) + elif payload_length == 127: + payload_length = struct.unpack(b'!Q', self.read_size(8)) + return Frame(fin, opcode, self.read_size(payload_length)) + + def read_message(self): + frames = [] + while True: + frame = self.read_frame() + frames.append(frame) + if frame.fin: + break + ans, opcode = [], None + for frame in frames: + if frame is frames[0]: + opcode = frame.opcode + if frame.fin == 0 and frame.opcode not in (BINARY, TEXT): + raise ValueError('Server sent a start frame with fin=0 and bad opcode') + ans.append(frame.payload) + ans = b''.join(ans) + if opcode == TEXT: + ans = ans.decode('utf-8') + return opcode, ans + + def write_message(self, msg, chunk_size=None): + w = MessageWriter(msg, self.mask, chunk_size) + while True: + frame = w.create_frame() + if frame is None: + break + self.socket.sendall(frame.getvalue()) + + def write_frame(self, fin, opcode, payload=b'', rsv=0, mask=True): + frame = create_frame(fin, opcode, payload, rsv=rsv, mask=self.mask if mask else None) + self.socket.sendall(frame) + + def write_close(self, code, reason=b''): + if isinstance(reason, type('')): + reason = reason.encode('utf-8') + self.write_frame(1, CLOSE, struct.pack(b'!H', code) + reason) + class TestHandler(object): @@ -14,13 +140,13 @@ class TestHandler(object): self.connection_state = {} def conn(self, cid): - ans = self.ws_connections.get(cid) + ans = self.connections.get(cid) if ans is not None: ans = ans() return ans def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders): - self.ws_connections[connection_id] = connection_ref + self.connections[connection_id] = connection_ref def handle_websocket_data(self, data, message_starting, message_finished, connection_id): pass @@ -42,7 +168,7 @@ class EchoHandler(TestHandler): j = '' if isinstance(self.msg_buf[0], type('')) else b'' msg = j.join(self.msg_buf) self.msg_buf = [] - self.conn(connection_id).send_websocket_message(msg) + self.conn(connection_id).send_websocket_message(msg, wakeup=False) class WSTestServer(TestServer): @@ -56,10 +182,32 @@ class WSTestServer(TestServer): def ws_handler(self): return self.loop.handler.websocket_handler + def ws_connect(self): + return WSClient(self.address[1]) + class WebSocketTest(BaseTest): + def simple_test(self, client, msgs, expected, close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'): + for msg in msgs: + client.write_message(msg) + for ex in expected: + if isinstance(ex, type('')): + ex = TEXT, ex + elif isinstance(ex, bytes): + ex = BINARY, ex + elif isinstance(ex, int): + ex = ex, b'' + self.ae(ex, client.read_message()) + if send_close: + client.write_close(close_code, close_reason) + opcode, data = client.read_message() + self.ae(opcode, CLOSE) + self.ae(close_code, struct.unpack_from(b'!H', data, 0)[0]) + def test_websocket_basic(self): 'Test basic interaction with the websocket server' - with WSTestServer(EchoHandler): - pass + with WSTestServer(EchoHandler) as server: + client = server.ws_connect() + st = partial(self.simple_test, client) + st([''], ['']) diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py index a1a9ca8321..1fe9bd1539 100644 --- a/src/calibre/srv/web_socket.py +++ b/src/calibre/srv/web_socket.py @@ -116,6 +116,9 @@ class ReadFrame(object): # {{{ self.state = self.read_payload self.pos = 0 self.frame_starting = True + if self.payload_length == 0: + conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin)) + self.state = None def read_payload(self, conn): bytes_left = self.payload_length - self.pos @@ -133,13 +136,13 @@ class ReadFrame(object): # {{{ # Sending frames {{{ -def create_frame(fin, opcode, payload, mask=None): +def create_frame(fin, opcode, payload, mask=None, rsv=0): if isinstance(payload, type('')): payload = payload.encode('utf-8') l = len(payload) opcode &= 0b1111 b1 = opcode | (0b10000000 if fin else 0) - b2 = 0 if mask is None else 0b10000000 + b2 = rsv | (0 if mask is None else 0b10000000) if l < 126: header = bytes(bytearray((b1, b2 | l))) elif 126 <= l <= 65535: @@ -159,13 +162,14 @@ def create_frame(fin, opcode, payload, mask=None): class MessageWriter(object): - def __init__(self, buf): - self.buf, self.data_type = buf, BINARY + def __init__(self, buf, mask=None, chunk_size=None): + self.buf, self.data_type, self.mask = buf, BINARY, mask if isinstance(buf, type('')): self.buf, self.data_type = BytesIO(buf.encode('utf-8')), TEXT elif isinstance(buf, bytes): self.buf = BytesIO(buf) buf = self.buf + self.chunk_size = chunk_size or SEND_CHUNK_SIZE try: pos = buf.tell() buf.seek(0, os.SEEK_END) @@ -179,12 +183,12 @@ class MessageWriter(object): if self.exhausted: return None buf = self.buf - raw = buf.read(SEND_CHUNK_SIZE) + raw = buf.read(self.chunk_size) has_more = True if self.size is None else self.size > buf.tell() fin = 0 if has_more and raw else 1 opcode = 0 if self.first_frame_created else self.data_type self.first_frame_created, self.exhausted = True, bool(fin) - return BytesIO(create_frame(fin, opcode, raw)) + return BytesIO(create_frame(fin, opcode, raw, self.mask)) # }}} conn_id = 0