More hammering on the WS server

Ensure that the server stops reading after
a close packet send is triggered
This commit is contained in:
Kovid Goyal 2015-10-25 20:52:41 +05:30
parent d2535aa9cf
commit f4b5d451fd
2 changed files with 46 additions and 12 deletions

View File

@ -11,7 +11,9 @@ from functools import partial
from hashlib import sha1
from calibre.srv.tests.base import BaseTest, TestServer
from calibre.srv.web_socket import GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE
from calibre.srv.web_socket import (
GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE,
PING, PONG, PROTOCOL_ERROR)
from calibre.utils.monotonic import monotonic
from calibre.utils.socket_inheritance import set_socket_inherit
@ -116,6 +118,11 @@ class WSClient(object):
return opcode, ans
def write_message(self, msg, chunk_size=None):
if isinstance(msg, tuple):
opcode, msg = msg
if isinstance(msg, type('')):
msg = msg.encode('utf-8')
return self.write_frame(1, opcode, msg)
w = MessageWriter(msg, self.mask, chunk_size)
while True:
frame = w.create_frame()
@ -187,7 +194,7 @@ class WSTestServer(TestServer):
class WebSocketTest(BaseTest):
def simple_test(self, client, msgs, expected, close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
def simple_test(self, client, msgs, expected=(), close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
for msg in msgs:
client.write_message(msg)
for ex in expected:
@ -214,3 +221,15 @@ class WebSocketTest(BaseTest):
for q in (b'', b'\xfe' * 125, b'\xfe' * 126, b'\xfe' * 127, b'\xfe' * 128, b'\xfe' * 65535, b'\xfe' * 65536):
client = server.connect()
self.simple_test(client, [q], [q])
for payload in ['', 'ping', b'\x00\xff\xfe\xfd\xfc\xfb\x00\xff', b"\xfe" * 125]:
client = server.connect()
self.simple_test(client, [(PING, payload)], [(PONG, payload)])
client = server.connect()
with server.silence_log:
self.simple_test(client, [(PING, 'a'*126)], close_code=PROTOCOL_ERROR, send_close=False)
for payload in (b'', b'pong'):
client = server.connect()
self.simple_test(client, [(PONG, payload)], [])

View File

@ -50,8 +50,14 @@ UNEXPECTED_ERROR = 1011
class ReadFrame(object): # {{{
def __init__(self):
self.reset()
def reset(self):
self.state = self.read_header0
def __call__(self, conn):
return self.state(conn)
def read_header0(self, conn):
data = conn.recv(1)
if not data:
@ -62,7 +68,7 @@ class ReadFrame(object): # {{{
self.state = self.read_header1
if self.opcode not in ALL_CODES:
conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.code)
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode)
return
def read_header1(self, conn):
@ -118,7 +124,7 @@ class ReadFrame(object): # {{{
self.frame_starting = True
if self.payload_length == 0:
conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin))
self.state = None
self.reset()
def read_payload(self, conn):
bytes_left = self.payload_length - self.pos
@ -133,8 +139,11 @@ class ReadFrame(object): # {{{
data[i] ^= self.mask[(self.pos + i) & 3]
data = bytes(data)
self.pos += len(data)
conn.ws_data_received(data, self.opcode, self.frame_starting, self.pos >= self.payload_length, bool(self.fin))
frame_finished = self.pos >= self.payload_length
conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, bool(self.fin))
self.frame_starting = False
if frame_finished:
self.reset()
# }}}
# Sending frames {{{
@ -213,6 +222,7 @@ class WebSocketConnection(HTTPConnection):
self.ws_close_received = self.ws_close_sent = False
conn_id += 1
self.websocket_connection_id = conn_id
self.stop_reading = False
def finalize_headers(self, inheaders):
upgrade = inheaders.get('Upgrade', None)
@ -239,7 +249,7 @@ class WebSocketConnection(HTTPConnection):
if self.write(buf):
if self.websocket_handler is None:
self.websocket_handler = DummyHandler()
self.current_recv_frame, self.current_recv_opcode = ReadFrame(), None
self.read_frame, self.current_recv_opcode = ReadFrame(), None
self.in_websocket_mode = True
try:
self.websocket_handler.handle_websocket_upgrade(self.websocket_connection_id, weakref.ref(self), inheaders)
@ -252,9 +262,7 @@ class WebSocketConnection(HTTPConnection):
def set_ws_state(self):
if self.ws_close_sent or self.ws_close_received:
if self.ws_close_sent:
self.set_state(READ, self.ws_duplex)
if self.ws_close_received:
self.ready = False
self.ready = False
else:
self.set_state(WRITE, self.ws_duplex)
return
@ -273,6 +281,12 @@ class WebSocketConnection(HTTPConnection):
else:
self.set_state(RDWR, self.ws_duplex)
if self.stop_reading:
if self.wait_for is READ:
self.ready = False
elif self.wait_for is RDWR:
self.wait_for = WRITE
def ws_duplex(self, event):
if event is READ:
self.ws_read()
@ -281,7 +295,8 @@ class WebSocketConnection(HTTPConnection):
self.set_ws_state()
def ws_read(self):
self.current_recv_frame.state(self)
if not self.stop_reading:
self.read_frame(self)
def ws_data_received(self, data, opcode, frame_starting, frame_finished, is_final_frame_of_message):
if opcode in CONTROL_CODES:
@ -300,8 +315,6 @@ class WebSocketConnection(HTTPConnection):
self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err))
self.current_recv_opcode = opcode
message_starting = True
if frame_finished:
self.current_recv_frame = ReadFrame()
message_finished = frame_finished and is_final_frame_of_message
if message_finished:
self.current_recv_opcode = None
@ -330,11 +343,13 @@ class WebSocketConnection(HTTPConnection):
self.control_frames.append(f)
if opcode == CLOSE:
self.ws_close_received = True
self.stop_reading = True
self.set_ws_state()
def websocket_close(self, code=NORMAL_CLOSE, reason=b''):
if isinstance(reason, type('')):
reason = reason.encode('utf-8')
self.stop_reading = True
reason = reason[:123]
if code is None and not reason:
f = BytesIO(create_frame(1, CLOSE, b''))