mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
More hammering on the WS server
Ensure that the server stops reading after a close packet send is triggered
This commit is contained in:
parent
d2535aa9cf
commit
f4b5d451fd
@ -11,7 +11,9 @@ from functools import partial
|
||||
from hashlib import sha1
|
||||
|
||||
from calibre.srv.tests.base import BaseTest, TestServer
|
||||
from calibre.srv.web_socket import GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE
|
||||
from calibre.srv.web_socket import (
|
||||
GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE,
|
||||
PING, PONG, PROTOCOL_ERROR)
|
||||
from calibre.utils.monotonic import monotonic
|
||||
from calibre.utils.socket_inheritance import set_socket_inherit
|
||||
|
||||
@ -116,6 +118,11 @@ class WSClient(object):
|
||||
return opcode, ans
|
||||
|
||||
def write_message(self, msg, chunk_size=None):
|
||||
if isinstance(msg, tuple):
|
||||
opcode, msg = msg
|
||||
if isinstance(msg, type('')):
|
||||
msg = msg.encode('utf-8')
|
||||
return self.write_frame(1, opcode, msg)
|
||||
w = MessageWriter(msg, self.mask, chunk_size)
|
||||
while True:
|
||||
frame = w.create_frame()
|
||||
@ -187,7 +194,7 @@ class WSTestServer(TestServer):
|
||||
|
||||
class WebSocketTest(BaseTest):
|
||||
|
||||
def simple_test(self, client, msgs, expected, close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
|
||||
def simple_test(self, client, msgs, expected=(), close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
|
||||
for msg in msgs:
|
||||
client.write_message(msg)
|
||||
for ex in expected:
|
||||
@ -214,3 +221,15 @@ class WebSocketTest(BaseTest):
|
||||
for q in (b'', b'\xfe' * 125, b'\xfe' * 126, b'\xfe' * 127, b'\xfe' * 128, b'\xfe' * 65535, b'\xfe' * 65536):
|
||||
client = server.connect()
|
||||
self.simple_test(client, [q], [q])
|
||||
|
||||
for payload in ['', 'ping', b'\x00\xff\xfe\xfd\xfc\xfb\x00\xff', b"\xfe" * 125]:
|
||||
client = server.connect()
|
||||
self.simple_test(client, [(PING, payload)], [(PONG, payload)])
|
||||
|
||||
client = server.connect()
|
||||
with server.silence_log:
|
||||
self.simple_test(client, [(PING, 'a'*126)], close_code=PROTOCOL_ERROR, send_close=False)
|
||||
|
||||
for payload in (b'', b'pong'):
|
||||
client = server.connect()
|
||||
self.simple_test(client, [(PONG, payload)], [])
|
||||
|
@ -50,8 +50,14 @@ UNEXPECTED_ERROR = 1011
|
||||
class ReadFrame(object): # {{{
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.state = self.read_header0
|
||||
|
||||
def __call__(self, conn):
|
||||
return self.state(conn)
|
||||
|
||||
def read_header0(self, conn):
|
||||
data = conn.recv(1)
|
||||
if not data:
|
||||
@ -62,7 +68,7 @@ class ReadFrame(object): # {{{
|
||||
self.state = self.read_header1
|
||||
if self.opcode not in ALL_CODES:
|
||||
conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
|
||||
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.code)
|
||||
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode)
|
||||
return
|
||||
|
||||
def read_header1(self, conn):
|
||||
@ -118,7 +124,7 @@ class ReadFrame(object): # {{{
|
||||
self.frame_starting = True
|
||||
if self.payload_length == 0:
|
||||
conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin))
|
||||
self.state = None
|
||||
self.reset()
|
||||
|
||||
def read_payload(self, conn):
|
||||
bytes_left = self.payload_length - self.pos
|
||||
@ -133,8 +139,11 @@ class ReadFrame(object): # {{{
|
||||
data[i] ^= self.mask[(self.pos + i) & 3]
|
||||
data = bytes(data)
|
||||
self.pos += len(data)
|
||||
conn.ws_data_received(data, self.opcode, self.frame_starting, self.pos >= self.payload_length, bool(self.fin))
|
||||
frame_finished = self.pos >= self.payload_length
|
||||
conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, bool(self.fin))
|
||||
self.frame_starting = False
|
||||
if frame_finished:
|
||||
self.reset()
|
||||
# }}}
|
||||
|
||||
# Sending frames {{{
|
||||
@ -213,6 +222,7 @@ class WebSocketConnection(HTTPConnection):
|
||||
self.ws_close_received = self.ws_close_sent = False
|
||||
conn_id += 1
|
||||
self.websocket_connection_id = conn_id
|
||||
self.stop_reading = False
|
||||
|
||||
def finalize_headers(self, inheaders):
|
||||
upgrade = inheaders.get('Upgrade', None)
|
||||
@ -239,7 +249,7 @@ class WebSocketConnection(HTTPConnection):
|
||||
if self.write(buf):
|
||||
if self.websocket_handler is None:
|
||||
self.websocket_handler = DummyHandler()
|
||||
self.current_recv_frame, self.current_recv_opcode = ReadFrame(), None
|
||||
self.read_frame, self.current_recv_opcode = ReadFrame(), None
|
||||
self.in_websocket_mode = True
|
||||
try:
|
||||
self.websocket_handler.handle_websocket_upgrade(self.websocket_connection_id, weakref.ref(self), inheaders)
|
||||
@ -252,9 +262,7 @@ class WebSocketConnection(HTTPConnection):
|
||||
def set_ws_state(self):
|
||||
if self.ws_close_sent or self.ws_close_received:
|
||||
if self.ws_close_sent:
|
||||
self.set_state(READ, self.ws_duplex)
|
||||
if self.ws_close_received:
|
||||
self.ready = False
|
||||
self.ready = False
|
||||
else:
|
||||
self.set_state(WRITE, self.ws_duplex)
|
||||
return
|
||||
@ -273,6 +281,12 @@ class WebSocketConnection(HTTPConnection):
|
||||
else:
|
||||
self.set_state(RDWR, self.ws_duplex)
|
||||
|
||||
if self.stop_reading:
|
||||
if self.wait_for is READ:
|
||||
self.ready = False
|
||||
elif self.wait_for is RDWR:
|
||||
self.wait_for = WRITE
|
||||
|
||||
def ws_duplex(self, event):
|
||||
if event is READ:
|
||||
self.ws_read()
|
||||
@ -281,7 +295,8 @@ class WebSocketConnection(HTTPConnection):
|
||||
self.set_ws_state()
|
||||
|
||||
def ws_read(self):
|
||||
self.current_recv_frame.state(self)
|
||||
if not self.stop_reading:
|
||||
self.read_frame(self)
|
||||
|
||||
def ws_data_received(self, data, opcode, frame_starting, frame_finished, is_final_frame_of_message):
|
||||
if opcode in CONTROL_CODES:
|
||||
@ -300,8 +315,6 @@ class WebSocketConnection(HTTPConnection):
|
||||
self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err))
|
||||
self.current_recv_opcode = opcode
|
||||
message_starting = True
|
||||
if frame_finished:
|
||||
self.current_recv_frame = ReadFrame()
|
||||
message_finished = frame_finished and is_final_frame_of_message
|
||||
if message_finished:
|
||||
self.current_recv_opcode = None
|
||||
@ -330,11 +343,13 @@ class WebSocketConnection(HTTPConnection):
|
||||
self.control_frames.append(f)
|
||||
if opcode == CLOSE:
|
||||
self.ws_close_received = True
|
||||
self.stop_reading = True
|
||||
self.set_ws_state()
|
||||
|
||||
def websocket_close(self, code=NORMAL_CLOSE, reason=b''):
|
||||
if isinstance(reason, type('')):
|
||||
reason = reason.encode('utf-8')
|
||||
self.stop_reading = True
|
||||
reason = reason[:123]
|
||||
if code is None and not reason:
|
||||
f = BytesIO(create_frame(1, CLOSE, b''))
|
||||
|
Loading…
x
Reference in New Issue
Block a user