diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index c4fc4dca1d..bcb92d2d16 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -91,9 +91,9 @@ class WSClient(object): raise ValueError('Got a frame with mask bit set from the server') payload_length = b2 & 0b01111111 if payload_length == 126: - payload_length = struct.unpack(b'!H', self.read_size(2)) + payload_length = struct.unpack(b'!H', self.read_size(2))[0] elif payload_length == 127: - payload_length = struct.unpack(b'!Q', self.read_size(8)) + payload_length = struct.unpack(b'!Q', self.read_size(8))[0] return Frame(fin, opcode, self.read_size(payload_length)) def read_message(self): @@ -182,7 +182,7 @@ class WSTestServer(TestServer): def ws_handler(self): return self.loop.handler.websocket_handler - def ws_connect(self): + def connect(self): return WSClient(self.address[1]) class WebSocketTest(BaseTest): @@ -208,6 +208,9 @@ class WebSocketTest(BaseTest): 'Test basic interaction with the websocket server' with WSTestServer(EchoHandler) as server: - client = server.ws_connect() - st = partial(self.simple_test, client) - st([''], ['']) + for q in ('', '*' * 125, '*' * 126, '*' * 127, '*' * 128, '*' * 65535, '*' * 65536): + client = server.connect() + self.simple_test(client, [q], [q]) + for q in (b'', b'\xfe' * 125, b'\xfe' * 126, b'\xfe' * 127, b'\xfe' * 128, b'\xfe' * 65535, b'\xfe' * 65536): + client = server.connect() + self.simple_test(client, [q], [q]) diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py index 1fe9bd1539..7d3939e424 100644 --- a/src/calibre/srv/web_socket.py +++ b/src/calibre/srv/web_socket.py @@ -5,7 +5,7 @@ from __future__ import (unicode_literals, division, absolute_import, print_function) -import codecs, httplib, struct, os, weakref, repr as reprlib, time +import codecs, httplib, struct, os, weakref, repr as reprlib, time, socket from base64 import standard_b64encode from functools import partial from hashlib import sha1 @@ -128,9 +128,12 @@ class ReadFrame(object): # {{{ return else: data = b'' - unmasked = bytes(bytearray(ord(byte) ^ self.mask[(self.pos + i) & 3] for i, byte in enumerate(data))) + data = bytearray(data) + for i in xrange(len(data)): + data[i] ^= self.mask[(self.pos + i) & 3] + data = bytes(data) self.pos += len(data) - conn.ws_data_received(unmasked, self.opcode, self.frame_starting, self.pos >= self.payload_length, bool(self.fin)) + conn.ws_data_received(data, self.opcode, self.frame_starting, self.pos >= self.payload_length, bool(self.fin)) self.frame_starting = False # }}} @@ -229,6 +232,7 @@ class WebSocketConnection(HTTPConnection): response = HANDSHAKE_STR % standard_b64encode(sha1(key + GUID_STR).digest()) self.optimize_for_sending_packet() + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.set_state(WRITE, self.upgrade_connection_to_ws, BytesIO(response.encode('ascii')), inheaders) def upgrade_connection_to_ws(self, buf, inheaders, event): @@ -243,6 +247,7 @@ class WebSocketConnection(HTTPConnection): self.log.exception('Error in WebSockets upgrade handler:') self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err)) self.set_ws_state() + self.end_send_optimization() def set_ws_state(self): if self.ws_close_sent or self.ws_close_received: @@ -295,9 +300,11 @@ class WebSocketConnection(HTTPConnection): self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err)) self.current_recv_opcode = opcode message_starting = True + if frame_finished: + self.current_recv_frame = ReadFrame() message_finished = frame_finished and is_final_frame_of_message if message_finished: - self.current_recv_frame, self.current_recv_opcode = ReadFrame(), None + self.current_recv_opcode = None if opcode == TEXT: if message_starting: self.frag_decoder.reset() @@ -343,6 +350,7 @@ class WebSocketConnection(HTTPConnection): return if self.send_buf is not None: if self.write(self.send_buf): + self.end_send_optimization() if getattr(self.send_buf, 'is_close_frame', False): self.ws_close_sent = True self.send_buf = None @@ -355,6 +363,8 @@ class WebSocketConnection(HTTPConnection): self.send_buf = self.sending.create_frame() if self.send_buf is None: self.sending = None + if self.send_buf is not None: + self.optimize_for_sending_packet() def close(self): if self.in_websocket_mode: