diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index bcb92d2d16..c6812580fc 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -11,7 +11,9 @@ from functools import partial from hashlib import sha1 from calibre.srv.tests.base import BaseTest, TestServer -from calibre.srv.web_socket import GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE +from calibre.srv.web_socket import ( + GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE, + PING, PONG, PROTOCOL_ERROR) from calibre.utils.monotonic import monotonic from calibre.utils.socket_inheritance import set_socket_inherit @@ -116,6 +118,11 @@ class WSClient(object): return opcode, ans def write_message(self, msg, chunk_size=None): + if isinstance(msg, tuple): + opcode, msg = msg + if isinstance(msg, type('')): + msg = msg.encode('utf-8') + return self.write_frame(1, opcode, msg) w = MessageWriter(msg, self.mask, chunk_size) while True: frame = w.create_frame() @@ -187,7 +194,7 @@ class WSTestServer(TestServer): class WebSocketTest(BaseTest): - def simple_test(self, client, msgs, expected, close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'): + def simple_test(self, client, msgs, expected=(), close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'): for msg in msgs: client.write_message(msg) for ex in expected: @@ -214,3 +221,15 @@ class WebSocketTest(BaseTest): for q in (b'', b'\xfe' * 125, b'\xfe' * 126, b'\xfe' * 127, b'\xfe' * 128, b'\xfe' * 65535, b'\xfe' * 65536): client = server.connect() self.simple_test(client, [q], [q]) + + for payload in ['', 'ping', b'\x00\xff\xfe\xfd\xfc\xfb\x00\xff', b"\xfe" * 125]: + client = server.connect() + self.simple_test(client, [(PING, payload)], [(PONG, payload)]) + + client = server.connect() + with server.silence_log: + self.simple_test(client, [(PING, 'a'*126)], close_code=PROTOCOL_ERROR, send_close=False) + + for payload in (b'', b'pong'): + client = server.connect() + self.simple_test(client, [(PONG, payload)], []) diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py index 7d3939e424..5baf0f6454 100644 --- a/src/calibre/srv/web_socket.py +++ b/src/calibre/srv/web_socket.py @@ -50,8 +50,14 @@ UNEXPECTED_ERROR = 1011 class ReadFrame(object): # {{{ def __init__(self): + self.reset() + + def reset(self): self.state = self.read_header0 + def __call__(self, conn): + return self.state(conn) + def read_header0(self, conn): data = conn.recv(1) if not data: @@ -62,7 +68,7 @@ class ReadFrame(object): # {{{ self.state = self.read_header1 if self.opcode not in ALL_CODES: conn.log.error('Unknown OPCODE from client: %r' % self.opcode) - conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.code) + conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) return def read_header1(self, conn): @@ -118,7 +124,7 @@ class ReadFrame(object): # {{{ self.frame_starting = True if self.payload_length == 0: conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin)) - self.state = None + self.reset() def read_payload(self, conn): bytes_left = self.payload_length - self.pos @@ -133,8 +139,11 @@ class ReadFrame(object): # {{{ data[i] ^= self.mask[(self.pos + i) & 3] data = bytes(data) self.pos += len(data) - conn.ws_data_received(data, self.opcode, self.frame_starting, self.pos >= self.payload_length, bool(self.fin)) + frame_finished = self.pos >= self.payload_length + conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, bool(self.fin)) self.frame_starting = False + if frame_finished: + self.reset() # }}} # Sending frames {{{ @@ -213,6 +222,7 @@ class WebSocketConnection(HTTPConnection): self.ws_close_received = self.ws_close_sent = False conn_id += 1 self.websocket_connection_id = conn_id + self.stop_reading = False def finalize_headers(self, inheaders): upgrade = inheaders.get('Upgrade', None) @@ -239,7 +249,7 @@ class WebSocketConnection(HTTPConnection): if self.write(buf): if self.websocket_handler is None: self.websocket_handler = DummyHandler() - self.current_recv_frame, self.current_recv_opcode = ReadFrame(), None + self.read_frame, self.current_recv_opcode = ReadFrame(), None self.in_websocket_mode = True try: self.websocket_handler.handle_websocket_upgrade(self.websocket_connection_id, weakref.ref(self), inheaders) @@ -252,9 +262,7 @@ class WebSocketConnection(HTTPConnection): def set_ws_state(self): if self.ws_close_sent or self.ws_close_received: if self.ws_close_sent: - self.set_state(READ, self.ws_duplex) - if self.ws_close_received: - self.ready = False + self.ready = False else: self.set_state(WRITE, self.ws_duplex) return @@ -273,6 +281,12 @@ class WebSocketConnection(HTTPConnection): else: self.set_state(RDWR, self.ws_duplex) + if self.stop_reading: + if self.wait_for is READ: + self.ready = False + elif self.wait_for is RDWR: + self.wait_for = WRITE + def ws_duplex(self, event): if event is READ: self.ws_read() @@ -281,7 +295,8 @@ class WebSocketConnection(HTTPConnection): self.set_ws_state() def ws_read(self): - self.current_recv_frame.state(self) + if not self.stop_reading: + self.read_frame(self) def ws_data_received(self, data, opcode, frame_starting, frame_finished, is_final_frame_of_message): if opcode in CONTROL_CODES: @@ -300,8 +315,6 @@ class WebSocketConnection(HTTPConnection): self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err)) self.current_recv_opcode = opcode message_starting = True - if frame_finished: - self.current_recv_frame = ReadFrame() message_finished = frame_finished and is_final_frame_of_message if message_finished: self.current_recv_opcode = None @@ -330,11 +343,13 @@ class WebSocketConnection(HTTPConnection): self.control_frames.append(f) if opcode == CLOSE: self.ws_close_received = True + self.stop_reading = True self.set_ws_state() def websocket_close(self, code=NORMAL_CLOSE, reason=b''): if isinstance(reason, type('')): reason = reason.encode('utf-8') + self.stop_reading = True reason = reason[:123] if code is None and not reason: f = BytesIO(create_frame(1, CLOSE, b''))