mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Testing infrastructure for the web socket server
This commit is contained in:
parent
4a6f04fd95
commit
fb579a6257
@ -4,8 +4,134 @@
|
|||||||
|
|
||||||
from __future__ import (unicode_literals, division, absolute_import,
|
from __future__ import (unicode_literals, division, absolute_import,
|
||||||
print_function)
|
print_function)
|
||||||
|
import socket, os, struct
|
||||||
|
from base64 import standard_b64encode
|
||||||
|
from collections import deque, namedtuple
|
||||||
|
from functools import partial
|
||||||
|
from hashlib import sha1
|
||||||
|
|
||||||
from calibre.srv.tests.base import BaseTest, TestServer
|
from calibre.srv.tests.base import BaseTest, TestServer
|
||||||
|
from calibre.srv.web_socket import GUID_STR, BINARY, TEXT, MessageWriter, create_frame, CLOSE, NORMAL_CLOSE
|
||||||
|
from calibre.utils.monotonic import monotonic
|
||||||
|
from calibre.utils.socket_inheritance import set_socket_inherit
|
||||||
|
|
||||||
|
HANDSHAKE_STR = '''\
|
||||||
|
GET / HTTP/1.1\r
|
||||||
|
Upgrade: websocket\r
|
||||||
|
Connection: Upgrade\r
|
||||||
|
Sec-WebSocket-Key: {}\r
|
||||||
|
Sec-WebSocket-Version: 13\r
|
||||||
|
''' + '\r\n'
|
||||||
|
|
||||||
|
Frame = namedtuple('Frame', 'fin opcode payload')
|
||||||
|
|
||||||
|
class WSClient(object):
|
||||||
|
|
||||||
|
def __init__(self, port, timeout=5):
|
||||||
|
self.timeout = timeout
|
||||||
|
self.socket = socket.create_connection(('localhost', port), timeout)
|
||||||
|
set_socket_inherit(self.socket, False)
|
||||||
|
self.key = standard_b64encode(os.urandom(8))
|
||||||
|
self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii'))
|
||||||
|
self.read_buf = deque()
|
||||||
|
self.read_upgrade_response()
|
||||||
|
self.mask = os.urandom(4)
|
||||||
|
self.frames = []
|
||||||
|
|
||||||
|
def read_upgrade_response(self):
|
||||||
|
from calibre.srv.http_request import read_headers
|
||||||
|
st = monotonic()
|
||||||
|
buf, idx = b'', -1
|
||||||
|
while idx == -1:
|
||||||
|
data = self.socket.recv(1024)
|
||||||
|
if not data:
|
||||||
|
raise ValueError('Server did not respond with a valid HTTP upgrade response')
|
||||||
|
buf += data
|
||||||
|
if len(buf) > 4096:
|
||||||
|
raise ValueError('Server responded with too much data to HTTP upgrade request')
|
||||||
|
if monotonic() - st > self.timeout:
|
||||||
|
raise ValueError('Timed out while waiting for server response to HTTP upgrade')
|
||||||
|
idx = buf.find(b'\r\n\r\n')
|
||||||
|
response, rest = buf[:idx+4], buf[idx+4:]
|
||||||
|
if rest:
|
||||||
|
self.read_buf.append(rest)
|
||||||
|
lines = (x + b'\r\n' for x in response.split(b'\r\n')[:-1])
|
||||||
|
rl = next(lines)
|
||||||
|
if rl != b'HTTP/1.1 101 Switching Protocols\r\n':
|
||||||
|
raise ValueError('Server did not respond with correct switching protocols line')
|
||||||
|
headers = read_headers(partial(next, lines))
|
||||||
|
key = standard_b64encode(sha1(self.key + GUID_STR).digest())
|
||||||
|
if headers.get('Sec-WebSocket-Accept') != key:
|
||||||
|
raise ValueError('Server did not respond with correct key in Sec-WebSocket-Accept')
|
||||||
|
|
||||||
|
def recv(self, max_amt):
|
||||||
|
if self.read_buf:
|
||||||
|
data = self.read_buf.popleft()
|
||||||
|
if len(data) <= max_amt:
|
||||||
|
return data
|
||||||
|
self.read_buf.appendleft(data[max_amt+1:])
|
||||||
|
return data[:max_amt + 1]
|
||||||
|
return self.socket.recv(max_amt)
|
||||||
|
|
||||||
|
def read_size(self, size):
|
||||||
|
ans = b''
|
||||||
|
while len(ans) < size:
|
||||||
|
d = self.recv(size - len(ans))
|
||||||
|
if not d:
|
||||||
|
raise ValueError('Connection to server closed, no data received')
|
||||||
|
ans += d
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def read_frame(self):
|
||||||
|
b1, b2 = bytearray(self.read_size(2))
|
||||||
|
fin = b1 & 0b10000000
|
||||||
|
opcode = b1 & 0b1111
|
||||||
|
masked = b2 & 0b10000000
|
||||||
|
if masked:
|
||||||
|
raise ValueError('Got a frame with mask bit set from the server')
|
||||||
|
payload_length = b2 & 0b01111111
|
||||||
|
if payload_length == 126:
|
||||||
|
payload_length = struct.unpack(b'!H', self.read_size(2))
|
||||||
|
elif payload_length == 127:
|
||||||
|
payload_length = struct.unpack(b'!Q', self.read_size(8))
|
||||||
|
return Frame(fin, opcode, self.read_size(payload_length))
|
||||||
|
|
||||||
|
def read_message(self):
|
||||||
|
frames = []
|
||||||
|
while True:
|
||||||
|
frame = self.read_frame()
|
||||||
|
frames.append(frame)
|
||||||
|
if frame.fin:
|
||||||
|
break
|
||||||
|
ans, opcode = [], None
|
||||||
|
for frame in frames:
|
||||||
|
if frame is frames[0]:
|
||||||
|
opcode = frame.opcode
|
||||||
|
if frame.fin == 0 and frame.opcode not in (BINARY, TEXT):
|
||||||
|
raise ValueError('Server sent a start frame with fin=0 and bad opcode')
|
||||||
|
ans.append(frame.payload)
|
||||||
|
ans = b''.join(ans)
|
||||||
|
if opcode == TEXT:
|
||||||
|
ans = ans.decode('utf-8')
|
||||||
|
return opcode, ans
|
||||||
|
|
||||||
|
def write_message(self, msg, chunk_size=None):
|
||||||
|
w = MessageWriter(msg, self.mask, chunk_size)
|
||||||
|
while True:
|
||||||
|
frame = w.create_frame()
|
||||||
|
if frame is None:
|
||||||
|
break
|
||||||
|
self.socket.sendall(frame.getvalue())
|
||||||
|
|
||||||
|
def write_frame(self, fin, opcode, payload=b'', rsv=0, mask=True):
|
||||||
|
frame = create_frame(fin, opcode, payload, rsv=rsv, mask=self.mask if mask else None)
|
||||||
|
self.socket.sendall(frame)
|
||||||
|
|
||||||
|
def write_close(self, code, reason=b''):
|
||||||
|
if isinstance(reason, type('')):
|
||||||
|
reason = reason.encode('utf-8')
|
||||||
|
self.write_frame(1, CLOSE, struct.pack(b'!H', code) + reason)
|
||||||
|
|
||||||
|
|
||||||
class TestHandler(object):
|
class TestHandler(object):
|
||||||
|
|
||||||
@ -14,13 +140,13 @@ class TestHandler(object):
|
|||||||
self.connection_state = {}
|
self.connection_state = {}
|
||||||
|
|
||||||
def conn(self, cid):
|
def conn(self, cid):
|
||||||
ans = self.ws_connections.get(cid)
|
ans = self.connections.get(cid)
|
||||||
if ans is not None:
|
if ans is not None:
|
||||||
ans = ans()
|
ans = ans()
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders):
|
def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders):
|
||||||
self.ws_connections[connection_id] = connection_ref
|
self.connections[connection_id] = connection_ref
|
||||||
|
|
||||||
def handle_websocket_data(self, data, message_starting, message_finished, connection_id):
|
def handle_websocket_data(self, data, message_starting, message_finished, connection_id):
|
||||||
pass
|
pass
|
||||||
@ -42,7 +168,7 @@ class EchoHandler(TestHandler):
|
|||||||
j = '' if isinstance(self.msg_buf[0], type('')) else b''
|
j = '' if isinstance(self.msg_buf[0], type('')) else b''
|
||||||
msg = j.join(self.msg_buf)
|
msg = j.join(self.msg_buf)
|
||||||
self.msg_buf = []
|
self.msg_buf = []
|
||||||
self.conn(connection_id).send_websocket_message(msg)
|
self.conn(connection_id).send_websocket_message(msg, wakeup=False)
|
||||||
|
|
||||||
|
|
||||||
class WSTestServer(TestServer):
|
class WSTestServer(TestServer):
|
||||||
@ -56,10 +182,32 @@ class WSTestServer(TestServer):
|
|||||||
def ws_handler(self):
|
def ws_handler(self):
|
||||||
return self.loop.handler.websocket_handler
|
return self.loop.handler.websocket_handler
|
||||||
|
|
||||||
|
def ws_connect(self):
|
||||||
|
return WSClient(self.address[1])
|
||||||
|
|
||||||
class WebSocketTest(BaseTest):
|
class WebSocketTest(BaseTest):
|
||||||
|
|
||||||
|
def simple_test(self, client, msgs, expected, close_code=NORMAL_CLOSE, send_close=True, close_reason=b'NORMAL CLOSE'):
|
||||||
|
for msg in msgs:
|
||||||
|
client.write_message(msg)
|
||||||
|
for ex in expected:
|
||||||
|
if isinstance(ex, type('')):
|
||||||
|
ex = TEXT, ex
|
||||||
|
elif isinstance(ex, bytes):
|
||||||
|
ex = BINARY, ex
|
||||||
|
elif isinstance(ex, int):
|
||||||
|
ex = ex, b''
|
||||||
|
self.ae(ex, client.read_message())
|
||||||
|
if send_close:
|
||||||
|
client.write_close(close_code, close_reason)
|
||||||
|
opcode, data = client.read_message()
|
||||||
|
self.ae(opcode, CLOSE)
|
||||||
|
self.ae(close_code, struct.unpack_from(b'!H', data, 0)[0])
|
||||||
|
|
||||||
def test_websocket_basic(self):
|
def test_websocket_basic(self):
|
||||||
'Test basic interaction with the websocket server'
|
'Test basic interaction with the websocket server'
|
||||||
|
|
||||||
with WSTestServer(EchoHandler):
|
with WSTestServer(EchoHandler) as server:
|
||||||
pass
|
client = server.ws_connect()
|
||||||
|
st = partial(self.simple_test, client)
|
||||||
|
st([''], [''])
|
||||||
|
@ -116,6 +116,9 @@ class ReadFrame(object): # {{{
|
|||||||
self.state = self.read_payload
|
self.state = self.read_payload
|
||||||
self.pos = 0
|
self.pos = 0
|
||||||
self.frame_starting = True
|
self.frame_starting = True
|
||||||
|
if self.payload_length == 0:
|
||||||
|
conn.ws_data_received(b'', self.opcode, True, True, bool(self.fin))
|
||||||
|
self.state = None
|
||||||
|
|
||||||
def read_payload(self, conn):
|
def read_payload(self, conn):
|
||||||
bytes_left = self.payload_length - self.pos
|
bytes_left = self.payload_length - self.pos
|
||||||
@ -133,13 +136,13 @@ class ReadFrame(object): # {{{
|
|||||||
|
|
||||||
# Sending frames {{{
|
# Sending frames {{{
|
||||||
|
|
||||||
def create_frame(fin, opcode, payload, mask=None):
|
def create_frame(fin, opcode, payload, mask=None, rsv=0):
|
||||||
if isinstance(payload, type('')):
|
if isinstance(payload, type('')):
|
||||||
payload = payload.encode('utf-8')
|
payload = payload.encode('utf-8')
|
||||||
l = len(payload)
|
l = len(payload)
|
||||||
opcode &= 0b1111
|
opcode &= 0b1111
|
||||||
b1 = opcode | (0b10000000 if fin else 0)
|
b1 = opcode | (0b10000000 if fin else 0)
|
||||||
b2 = 0 if mask is None else 0b10000000
|
b2 = rsv | (0 if mask is None else 0b10000000)
|
||||||
if l < 126:
|
if l < 126:
|
||||||
header = bytes(bytearray((b1, b2 | l)))
|
header = bytes(bytearray((b1, b2 | l)))
|
||||||
elif 126 <= l <= 65535:
|
elif 126 <= l <= 65535:
|
||||||
@ -159,13 +162,14 @@ def create_frame(fin, opcode, payload, mask=None):
|
|||||||
|
|
||||||
class MessageWriter(object):
|
class MessageWriter(object):
|
||||||
|
|
||||||
def __init__(self, buf):
|
def __init__(self, buf, mask=None, chunk_size=None):
|
||||||
self.buf, self.data_type = buf, BINARY
|
self.buf, self.data_type, self.mask = buf, BINARY, mask
|
||||||
if isinstance(buf, type('')):
|
if isinstance(buf, type('')):
|
||||||
self.buf, self.data_type = BytesIO(buf.encode('utf-8')), TEXT
|
self.buf, self.data_type = BytesIO(buf.encode('utf-8')), TEXT
|
||||||
elif isinstance(buf, bytes):
|
elif isinstance(buf, bytes):
|
||||||
self.buf = BytesIO(buf)
|
self.buf = BytesIO(buf)
|
||||||
buf = self.buf
|
buf = self.buf
|
||||||
|
self.chunk_size = chunk_size or SEND_CHUNK_SIZE
|
||||||
try:
|
try:
|
||||||
pos = buf.tell()
|
pos = buf.tell()
|
||||||
buf.seek(0, os.SEEK_END)
|
buf.seek(0, os.SEEK_END)
|
||||||
@ -179,12 +183,12 @@ class MessageWriter(object):
|
|||||||
if self.exhausted:
|
if self.exhausted:
|
||||||
return None
|
return None
|
||||||
buf = self.buf
|
buf = self.buf
|
||||||
raw = buf.read(SEND_CHUNK_SIZE)
|
raw = buf.read(self.chunk_size)
|
||||||
has_more = True if self.size is None else self.size > buf.tell()
|
has_more = True if self.size is None else self.size > buf.tell()
|
||||||
fin = 0 if has_more and raw else 1
|
fin = 0 if has_more and raw else 1
|
||||||
opcode = 0 if self.first_frame_created else self.data_type
|
opcode = 0 if self.first_frame_created else self.data_type
|
||||||
self.first_frame_created, self.exhausted = True, bool(fin)
|
self.first_frame_created, self.exhausted = True, bool(fin)
|
||||||
return BytesIO(create_frame(fin, opcode, raw))
|
return BytesIO(create_frame(fin, opcode, raw, self.mask))
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
conn_id = 0
|
conn_id = 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user