diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index c6812580fc..3c2bb3ff1e 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -130,8 +130,8 @@ class WSClient(object): break self.socket.sendall(frame.getvalue()) - def write_frame(self, fin, opcode, payload=b'', rsv=0, mask=True): - frame = create_frame(fin, opcode, payload, rsv=rsv, mask=self.mask if mask else None) + def write_frame(self, fin=1, opcode=CLOSE, payload=b'', rsv=0, mask=True): + frame = create_frame(fin, opcode, payload, rsv=(rsv << 4), mask=self.mask if mask else None) self.socket.sendall(frame) def write_close(self, code, reason=b''): @@ -196,7 +196,10 @@ class WebSocketTest(BaseTest): def simple_test(self, client, msgs, expected=(), close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'): for msg in msgs: - client.write_message(msg) + if isinstance(msg, dict): + client.write_frame(**msg) + else: + client.write_message(msg) for ex in expected: if isinstance(ex, type('')): ex = TEXT, ex @@ -233,3 +236,11 @@ class WebSocketTest(BaseTest): for payload in (b'', b'pong'): client = server.connect() self.simple_test(client, [(PONG, payload)], []) + + with server.silence_log: + for rsv in xrange(1, 7): + client = server.connect() + self.simple_test(client, [{'rsv':rsv, 'opcode':BINARY}], [], close_code=PROTOCOL_ERROR, send_close=False) + for opcode in (3, 4, 5, 6, 7, 11, 12, 13, 14, 15): + client = server.connect() + self.simple_test(client, [{'opcode':opcode}], [], close_code=PROTOCOL_ERROR, send_close=False) diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py index 5baf0f6454..d9d7a24b6f 100644 --- a/src/calibre/srv/web_socket.py +++ b/src/calibre/srv/web_socket.py @@ -64,6 +64,11 @@ class ReadFrame(object): # {{{ return b = ord(data) self.fin = b & 0b10000000 + if b & 0b01110000: + conn.log.error('RSV bits set in frame from client') + conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set') + return + self.opcode = b & 0b1111 self.state = self.read_header1 if self.opcode not in ALL_CODES: @@ -153,8 +158,8 @@ def create_frame(fin, opcode, payload, mask=None, rsv=0): payload = payload.encode('utf-8') l = len(payload) opcode &= 0b1111 - b1 = opcode | (0b10000000 if fin else 0) - b2 = rsv | (0 if mask is None else 0b10000000) + b1 = opcode | (0b10000000 if fin else 0) | (rsv & 0b01110000) + b2 = 0 if mask is None else 0b10000000 if l < 126: header = bytes(bytearray((b1, b2 | l))) elif 126 <= l <= 65535: