diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index 3c2bb3ff1e..e3672c7659 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -244,3 +244,10 @@ class WebSocketTest(BaseTest): for opcode in (3, 4, 5, 6, 7, 11, 12, 13, 14, 15): client = server.connect() self.simple_test(client, [{'opcode':opcode}], [], close_code=PROTOCOL_ERROR, send_close=False) + + for opcode in (PING, PONG): + client = server.connect() + self.simple_test(client, [ + {'opcode':opcode, 'payload':'f1', 'fin':0}, {'opcode':opcode, 'payload':'f2'} + ], close_code=PROTOCOL_ERROR, send_close=False) + diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py index d9d7a24b6f..81684b00e3 100644 --- a/src/calibre/srv/web_socket.py +++ b/src/calibre/srv/web_socket.py @@ -7,6 +7,7 @@ from __future__ import (unicode_literals, division, absolute_import, import codecs, httplib, struct, os, weakref, repr as reprlib, time, socket from base64 import standard_b64encode +from collections import deque from functools import partial from hashlib import sha1 from io import BytesIO @@ -63,7 +64,7 @@ class ReadFrame(object): # {{{ if not data: return b = ord(data) - self.fin = b & 0b10000000 + self.fin = bool(b & 0b10000000) if b & 0b01110000: conn.log.error('RSV bits set in frame from client') conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set') @@ -75,6 +76,10 @@ class ReadFrame(object): # {{{ conn.log.error('Unknown OPCODE from client: %r' % self.opcode) conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) return + if not self.fin and self.opcode in CONTROL_CODES: + conn.log.error('Fragmented control frame from client') + conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame') + return def read_header1(self, conn): data = conn.recv(1) @@ -85,11 +90,13 @@ class ReadFrame(object): # {{{ if not self.mask: conn.log.error('Unmasked packet from client') conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') + self.reset() return self.payload_length = b & 0b01111111 if self.opcode in (PING, PONG) and self.payload_length > 125: conn.log.error('Too large ping packet from client') conn.websocket_close(PROTOCOL_ERROR, 'Ping packet too large') + self.reset() return self.mask_buf = b'' if self.payload_length == 126: @@ -128,7 +135,7 @@ class ReadFrame(object): # {{{ self.pos = 0 self.frame_starting = True if self.payload_length == 0: - conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin)) + conn.ws_data_received(b'', self.opcode, True, True, self.fin) self.reset() def read_payload(self, conn): @@ -145,7 +152,7 @@ class ReadFrame(object): # {{{ data = bytes(data) self.pos += len(data) frame_finished = self.pos >= self.payload_length - conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, bool(self.fin)) + conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin) self.frame_starting = False if frame_finished: self.reset() @@ -219,7 +226,7 @@ class WebSocketConnection(HTTPConnection): global conn_id HTTPConnection.__init__(self, *args, **kwargs) self.sendq = Queue() - self.control_frames = [] + self.control_frames = deque() self.cf_lock = Lock() self.sending = None self.send_buf = None @@ -377,7 +384,7 @@ class WebSocketConnection(HTTPConnection): else: with self.cf_lock: try: - self.send_buf = self.control_frames.pop() + self.send_buf = self.control_frames.popleft() except IndexError: if self.sending is not None: self.send_buf = self.sending.create_frame() @@ -396,7 +403,8 @@ class WebSocketConnection(HTTPConnection): try: if self.send_buf is None: self.websocket_close(SHUTTING_DOWN, 'Shutting down') - self.write(self.control_frames.pop()) + with self.cf_lock: + self.write(self.control_frames.pop()) except Exception: pass Connection.close(self)