Make the EchoHandler more efficient by implementing an interface to send individual fragments instead of only complete messages.

This commit is contained in:
Kovid Goyal 2015-10-26 18:30:18 +05:30
parent 87d16a3f74
commit 8fc0a9f3c4
2 changed files with 61 additions and 82 deletions

View File

@ -13,7 +13,7 @@ from hashlib import sha1
from calibre.srv.tests.base import BaseTest, TestServer from calibre.srv.tests.base import BaseTest, TestServer
from calibre.srv.web_socket import ( from calibre.srv.web_socket import (
GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE, GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE,
PING, PONG, PROTOCOL_ERROR, CONTINUATION, INCONSISTENT_DATA) PING, PONG, PROTOCOL_ERROR, CONTINUATION, INCONSISTENT_DATA, CONTROL_CODES)
from calibre.utils.monotonic import monotonic from calibre.utils.monotonic import monotonic
from calibre.utils.socket_inheritance import set_socket_inherit from calibre.utils.socket_inheritance import set_socket_inherit
@ -80,13 +80,16 @@ class WSClient(object):
while len(ans) < size: while len(ans) < size:
d = self.recv(size - len(ans)) d = self.recv(size - len(ans))
if not d: if not d:
raise ValueError('Connection to server closed, no data received') return None
ans += d ans += d
return ans return ans
def read_frame(self): def read_frame(self):
b1, b2 = bytearray(self.read_size(2)) x = self.read_size(2)
fin = b1 & 0b10000000 if x is None:
return None
b1, b2 = bytearray(x)
fin = bool(b1 & 0b10000000)
opcode = b1 & 0b1111 opcode = b1 & 0b1111
masked = b2 & 0b10000000 masked = b2 & 0b10000000
if masked: if masked:
@ -98,24 +101,26 @@ class WSClient(object):
payload_length = struct.unpack(b'!Q', self.read_size(8))[0] payload_length = struct.unpack(b'!Q', self.read_size(8))[0]
return Frame(fin, opcode, self.read_size(payload_length)) return Frame(fin, opcode, self.read_size(payload_length))
def read_message(self): def read_messages(self):
frames = [] messages, control_frames = [], []
msg_buf, opcode = [], None
while True: while True:
frame = self.read_frame() frame = self.read_frame()
frames.append(frame) if frame is None or frame.payload is None:
if frame.fin:
break break
ans, opcode = [], None if frame.opcode in CONTROL_CODES:
for frame in frames: control_frames.append((frame.opcode, frame.payload))
if frame is frames[0]: else:
opcode = frame.opcode if opcode is None:
if frame.fin == 0 and frame.opcode not in (BINARY, TEXT): opcode = frame.opcode
raise ValueError('Server sent a start frame with fin=0 and bad opcode') msg_buf.append(frame.payload)
ans.append(frame.payload) if frame.fin:
ans = b''.join(ans) data = b''.join(msg_buf)
if opcode == TEXT: if opcode == TEXT:
ans = ans.decode('utf-8') data = data.decode('utf-8', 'replace')
return opcode, ans messages.append((opcode, data))
msg_buf, opcode = [], None
return messages, control_frames
def write_message(self, msg, chunk_size=None): def write_message(self, msg, chunk_size=None):
if isinstance(msg, tuple): if isinstance(msg, tuple):
@ -140,47 +145,9 @@ class WSClient(object):
self.write_frame(1, CLOSE, struct.pack(b'!H', code) + reason) self.write_frame(1, CLOSE, struct.pack(b'!H', code) + reason)
class TestHandler(object):
def __init__(self):
self.connections = {}
self.connection_state = {}
def conn(self, cid):
ans = self.connections.get(cid)
if ans is not None:
ans = ans()
return ans
def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders):
self.connections[connection_id] = connection_ref
def handle_websocket_data(self, data, message_starting, message_finished, connection_id):
pass
def handle_websocket_close(self, connection_id):
self.connections.pop(connection_id, None)
class EchoHandler(TestHandler):
def __init__(self):
TestHandler.__init__(self)
self.msg_buf = []
def handle_websocket_data(self, data, message_starting, message_finished, connection_id):
if message_starting:
self.msg_buf = []
self.msg_buf.append(data)
if message_finished:
j = '' if isinstance(self.msg_buf[0], type('')) else b''
msg = j.join(self.msg_buf)
self.msg_buf = []
self.conn(connection_id).send_websocket_message(msg, wakeup=False)
class WSTestServer(TestServer): class WSTestServer(TestServer):
def __init__(self, handler=TestHandler): def __init__(self, handler):
TestServer.__init__(self, None) TestServer.__init__(self, None)
from calibre.srv.http_response import create_http_handler from calibre.srv.http_response import create_http_handler
self.loop.handler = create_http_handler(websocket_handler=handler()) self.loop.handler = create_http_handler(websocket_handler=handler())
@ -201,8 +168,8 @@ class WebSocketTest(BaseTest):
client.write_frame(**msg) client.write_frame(**msg)
else: else:
client.write_message(msg) client.write_message(msg)
ordered = not isinstance(expected, (set, frozenset))
pexpected, replies = set(), set() expected_messages, expected_controls = [], []
for ex in expected: for ex in expected:
if isinstance(ex, type('')): if isinstance(ex, type('')):
ex = TEXT, ex ex = TEXT, ex
@ -210,20 +177,22 @@ class WebSocketTest(BaseTest):
ex = BINARY, ex ex = BINARY, ex
elif isinstance(ex, int): elif isinstance(ex, int):
ex = ex, b'' ex = ex, b''
if ordered: if ex[0] in CONTROL_CODES:
self.ae(ex, client.read_message()) expected_controls.append(ex)
else: else:
pexpected.add(ex), replies.add(client.read_message()) expected_messages.append(ex)
if not ordered:
self.ae(pexpected, replies)
if send_close: if send_close:
client.write_close(close_code, close_reason) client.write_close(close_code, close_reason)
opcode, data = client.read_message() messages, control_frames = client.read_messages()
self.ae(opcode, CLOSE) self.ae(expected_messages, messages)
self.ae(close_code, struct.unpack_from(b'!H', data, 0)[0]) self.assertGreaterEqual(len(control_frames), 1)
self.ae(expected_controls, control_frames[:-1])
self.ae(control_frames[-1][0], CLOSE)
self.ae(close_code, struct.unpack_from(b'!H', control_frames[-1][1], 0)[0])
def test_websocket_basic(self): def test_websocket_basic(self):
'Test basic interaction with the websocket server' 'Test basic interaction with the websocket server'
from calibre.srv.web_socket import EchoHandler
with WSTestServer(EchoHandler) as server: with WSTestServer(EchoHandler) as server:
simple_test = partial(self.simple_test, server) simple_test = partial(self.simple_test, server)
@ -284,7 +253,7 @@ class WebSocketTest(BaseTest):
simple_test([ simple_test([
{'opcode':TEXT, 'payload':fragments[0], 'fin':0}, (PING, b'pong'), {'opcode':CONTINUATION, 'payload':fragments[1]} {'opcode':TEXT, 'payload':fragments[0], 'fin':0}, (PING, b'pong'), {'opcode':CONTINUATION, 'payload':fragments[1]}
], {(PONG, b'pong'), ''.join(fragments)}) ], [(PONG, b'pong'), ''.join(fragments)])
fragments = '12345' fragments = '12345'
simple_test([ simple_test([
@ -293,7 +262,7 @@ class WebSocketTest(BaseTest):
{'opcode':CONTINUATION, 'payload':fragments[2], 'fin':0}, {'opcode':CONTINUATION, 'payload':fragments[3], 'fin':0}, {'opcode':CONTINUATION, 'payload':fragments[2], 'fin':0}, {'opcode':CONTINUATION, 'payload':fragments[3], 'fin':0},
(PING, b'2'), (PING, b'2'),
{'opcode':CONTINUATION, 'payload':fragments[4]} {'opcode':CONTINUATION, 'payload':fragments[4]}
], {(PONG, b'1'), (PONG, b'2'), fragments}) ], [(PONG, b'1'), (PONG, b'2'), fragments])
simple_test([ simple_test([
{'opcode':TEXT, 'fin':0}, {'opcode':CONTINUATION, 'fin':0}, {'opcode':CONTINUATION},], ['']) {'opcode':TEXT, 'fin':0}, {'opcode':CONTINUATION, 'fin':0}, {'opcode':CONTINUATION},], [''])
@ -320,6 +289,7 @@ class WebSocketTest(BaseTest):
simple_test([(CLOSE, struct.pack(b'!H', code))], send_close=False, close_code=PROTOCOL_ERROR) simple_test([(CLOSE, struct.pack(b'!H', code))], send_close=False, close_code=PROTOCOL_ERROR)
def test_websocket_perf(self): def test_websocket_perf(self):
from calibre.srv.web_socket import EchoHandler
with WSTestServer(EchoHandler) as server: with WSTestServer(EchoHandler) as server:
simple_test = partial(self.simple_test, server) simple_test = partial(self.simple_test, server)
for sz in (64, 256, 1024, 4096, 8192, 16384): for sz in (64, 256, 1024, 4096, 8192, 16384):

View File

@ -226,6 +226,7 @@ conn_id = 0
class WebSocketConnection(HTTPConnection): class WebSocketConnection(HTTPConnection):
# Internal API {{{
in_websocket_mode = False in_websocket_mode = False
websocket_handler = None websocket_handler = None
@ -431,14 +432,31 @@ class WebSocketConnection(HTTPConnection):
Connection.close(self) Connection.close(self)
else: else:
HTTPConnection.close(self) HTTPConnection.close(self)
# }}}
def send_websocket_message(self, buf, wakeup=True): def send_websocket_message(self, buf, wakeup=True):
''' Send a complete message. This class will take care of splitting it
into appropriate frames automatically. `buf` must be a file like object. '''
self.sendq.put(MessageWriter(buf)) self.sendq.put(MessageWriter(buf))
self.wait_for = RDWR self.wait_for = RDWR
if wakeup: if wakeup:
self.wakeup() self.wakeup()
def send_websocket_frame(self, data, is_first=True, is_last=True):
''' Useful for streaming handlers that want to break up messages into
frames themselves. Note that these frames will be interleaved with
control frames, so they should not be too large. '''
opcode = (TEXT if isinstance(data, type('')) else BINARY) if is_first else CONTINUATION
fin = 1 if is_last else 0
frame = create_frame(fin, opcode, data)
with self.cf_lock:
self.control_frames.append(BytesIO(frame))
def handle_websocket_data(self, data, message_starting, message_finished): def handle_websocket_data(self, data, message_starting, message_finished):
''' Called when some data is received from the remote client. In general the
data may not constitute a complete "message", use the message_starting
and message_finished flags to re-assemble it into a complete message in
the handler. '''
self.websocket_handler.handle_websocket_data(data, message_starting, message_finished, self.websocket_connection_id) self.websocket_handler.handle_websocket_data(data, message_starting, message_finished, self.websocket_connection_id)
class DummyHandler(object): class DummyHandler(object):
@ -458,10 +476,9 @@ class DummyHandler(object):
# Run this file with calibre-debug and use wstest to run the Autobahn test # Run this file with calibre-debug and use wstest to run the Autobahn test
# suite # suite
class EchoClientHandler(object): class EchoHandler(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.msg_buf = []
self.ws_connections = {} self.ws_connections = {}
def conn(self, cid): def conn(self, cid):
@ -474,21 +491,13 @@ class EchoClientHandler(object):
self.ws_connections[connection_id] = connection_ref self.ws_connections[connection_id] = connection_ref
def handle_websocket_data(self, data, message_starting, message_finished, connection_id): def handle_websocket_data(self, data, message_starting, message_finished, connection_id):
if message_starting: self.conn(connection_id).send_websocket_frame(data, message_starting, message_finished)
self.msg_buf = []
self.msg_buf.append(data)
if message_finished:
j = '' if isinstance(self.msg_buf[0], type('')) else b''
msg = j.join(self.msg_buf)
self.msg_buf = []
# print('Received message from client:', reprlib.repr(msg))
self.conn(connection_id).send_websocket_message(msg)
def handle_websocket_close(self, connection_id): def handle_websocket_close(self, connection_id):
self.ws_connections.pop(connection_id, None) self.ws_connections.pop(connection_id, None)
if __name__ == '__main__': if __name__ == '__main__':
s = ServerLoop(create_http_handler(websocket_handler=EchoClientHandler())) s = ServerLoop(create_http_handler(websocket_handler=EchoHandler()))
with HandleInterrupt(s.wakeup): with HandleInterrupt(s.wakeup):
s.serve_forever() s.serve_forever()
# }}} # }}}