diff --git a/src/calibre/srv/http_request.py b/src/calibre/srv/http_request.py index 38f9b6b33e..d444182286 100644 --- a/src/calibre/srv/http_request.py +++ b/src/calibre/srv/http_request.py @@ -105,6 +105,7 @@ comma_separated_headers = { decoded_headers = { 'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', 'WWW-Authenticate', 'Authorization', + 'Sec-WebSocket-Key', 'Sec-WebSocket-Version', 'Sec-WebSocket-Protocol', } | comma_separated_headers uppercase_headers = {'WWW', 'TE'} @@ -114,6 +115,8 @@ def normalize_header_name(name): q = parts[0].upper() if q in uppercase_headers: parts[0] = q + if len(parts) == 3 and parts[1] == 'Websocket': + parts[1] = 'WebSocket' return '-'.join(parts) class HTTPHeaderParser(object): diff --git a/src/calibre/srv/http_response.py b/src/calibre/srv/http_response.py index 524222a121..926e483ef3 100644 --- a/src/calibre/srv/http_response.py +++ b/src/calibre/srv/http_response.py @@ -621,13 +621,19 @@ class HTTPConnection(HTTPRequest): request.status_code = httplib.PARTIAL_CONTENT return output -def create_http_handler(handler): +def create_http_handler(handler=None, websocket_handler=None): + from calibre.srv.web_socket import WebSocketConnection static_cache = {} translator_cache = {} + if handler is None: + def dummy_http_handler(data): + return 'Hello' + handler = dummy_http_handler @wraps(handler) def wrapper(*args, **kwargs): - ans = HTTPConnection(*args, **kwargs) + ans = WebSocketConnection(*args, **kwargs) ans.request_handler = handler + ans.websocket_handler = websocket_handler ans.static_cache = static_cache ans.translator_cache = translator_cache return ans diff --git a/src/calibre/srv/loop.py b/src/calibre/srv/loop.py index 4be6666944..9be9f7b6c3 100644 --- a/src/calibre/srv/loop.py +++ b/src/calibre/srv/loop.py @@ -117,8 +117,8 @@ class ReadBuffer(object): # {{{ class Connection(object): # {{{ - def __init__(self, socket, opts, ssl_context, tdir, addr, pool, log): - self.opts, self.pool, self.log = opts, pool, log + def __init__(self, socket, opts, ssl_context, tdir, addr, pool, log, wakeup): + self.opts, self.pool, self.log, self.wakeup = opts, pool, log, wakeup try: self.remote_addr = addr[0] self.remote_port = addr[1] @@ -286,7 +286,7 @@ class ServerLoop(object): self.opts = opts or Options() self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG) - ba = (opts.listen_on, int(opts.port)) + ba = (self.opts.listen_on, int(self.opts.port)) if not ba[0]: # AI_PASSIVE does not work with host of '' or None ba = ('0.0.0.0', ba[1]) @@ -532,7 +532,7 @@ class ServerLoop(object): if sock is not None: s = sock.fileno() if s > -1: - self.connection_map[s] = conn = self.handler(sock, self.opts, self.ssl_context, self.tdir, addr, self.pool, self.log) + self.connection_map[s] = conn = self.handler(sock, self.opts, self.ssl_context, self.tdir, addr, self.pool, self.log, self.wakeup) if self.ssl_context is not None: yield s, conn, RDWR elif s == control: diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py new file mode 100644 index 0000000000..15c8f78463 --- /dev/null +++ b/src/calibre/srv/tests/web_sockets.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +# License: GPLv3 Copyright: 2015, Kovid Goyal + +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +from calibre.srv.tests.base import BaseTest, TestServer + +class TestHandler(object): + + def __init__(self): + self.connections = {} + self.connection_state = {} + + def conn(self, cid): + ans = self.ws_connections.get(cid) + if ans is not None: + ans = ans() + return ans + + def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders): + self.ws_connections[connection_id] = connection_ref + + def handle_websocket_data(self, data, message_starting, message_finished, connection_id): + pass + + def handle_websocket_close(self, connection_id): + self.connections.pop(connection_id, None) + +class EchoHandler(TestHandler): + + def __init__(self): + TestHandler.__init__(self) + self.msg_buf = [] + + def handle_websocket_data(self, data, message_starting, message_finished, connection_id): + if message_starting: + self.msg_buf = [] + self.msg_buf.append(data) + if message_finished: + j = '' if isinstance(self.msg_buf[0], type('')) else b'' + msg = j.join(self.msg_buf) + self.msg_buf = [] + self.conn(connection_id).send_websocket_message(msg) + + +class WSTestServer(TestServer): + + def __init__(self, handler=TestHandler): + TestServer.__init__(self, None) + from calibre.srv.http_response import create_http_handler + self.loop.handler = create_http_handler(websocket_handler=handler()) + + @property + def ws_handler(self): + return self.loop.handler.websocket_handler + +class WebSocketTest(BaseTest): + + def test_websocket_basic(self): + 'Test basic interaction with the websocket server' + + with WSTestServer(EchoHandler): + pass diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py new file mode 100644 index 0000000000..a1a9ca8321 --- /dev/null +++ b/src/calibre/srv/web_socket.py @@ -0,0 +1,479 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +# License: GPLv3 Copyright: 2015, Kovid Goyal + +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +import codecs, httplib, struct, os, weakref, repr as reprlib, time +from base64 import standard_b64encode +from functools import partial +from hashlib import sha1 +from io import BytesIO +from Queue import Queue, Empty +from threading import Lock, Thread + +from calibre import as_unicode +from calibre.srv.loop import ServerLoop, HandleInterrupt, WRITE, READ, RDWR, Connection +from calibre.srv.http_response import HTTPConnection, create_http_handler +from calibre.srv.utils import DESIRED_SEND_BUFFER_SIZE + +HANDSHAKE_STR = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: %s\r\n\r\n" +) +GUID_STR = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +CONTINUATION = 0x0 +TEXT = 0x1 +BINARY = 0x2 +CLOSE = 0x8 +PING = 0x9 +PONG = 0xA +CONTROL_CODES = (CLOSE, PING, PONG) +ALL_CODES = CONTROL_CODES + (CONTINUATION, TEXT, BINARY) + +CHUNK_SIZE = 16 * 1024 +SEND_CHUNK_SIZE = DESIRED_SEND_BUFFER_SIZE - 16 + +NORMAL_CLOSE = 1000 +SHUTTING_DOWN = 1001 +PROTOCOL_ERROR = 1002 +UNSUPPORTED_DATA = 1003 +INCONSISTENT_DATA = 1007 +POLICY_VIOLATION = 1008 +MESSAGE_TOO_BIG = 1009 +UNEXPECTED_ERROR = 1011 + +class ReadFrame(object): # {{{ + + def __init__(self): + self.state = self.read_header0 + + def read_header0(self, conn): + data = conn.recv(1) + if not data: + return + b = ord(data) + self.fin = b & 0b10000000 + self.opcode = b & 0b1111 + self.state = self.read_header1 + if self.opcode not in ALL_CODES: + conn.log.error('Unknown OPCODE from client: %r' % self.opcode) + conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.code) + return + + def read_header1(self, conn): + data = conn.recv(1) + if not data: + return + b = ord(data) + self.mask = b & 0b10000000 + if not self.mask: + conn.log.error('Unmasked packet from client') + conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') + return + self.payload_length = b & 0b01111111 + if self.opcode in (PING, PONG) and self.payload_length > 125: + conn.log.error('Too large ping packet from client') + conn.websocket_close(PROTOCOL_ERROR, 'Ping packet too large') + return + self.mask_buf = b'' + if self.payload_length == 126: + self.plbuf = b'' + self.state = partial(self.read_payload_length, 2) + elif self.payload_length == 127: + self.plbuf = b'' + self.state = partial(self.read_payload_length, 8) + else: + self.state = self.read_masking_key + + def read_payload_length(self, size_in_bytes, conn): + num_left = size_in_bytes - len(self.plbuf) + data = conn.recv(num_left) + if not data: + return + self.plbuf += data + if len(self.plbuf) < size_in_bytes: + return + fmt = b'!H' if size_in_bytes == 2 else b'!Q' + self.payload_length = struct.unpack(fmt, self.plbuf)[0] + del self.plbuf + self.state = self.read_masking_key + + def read_masking_key(self, conn): + num_left = 4 - len(self.mask_buf) + data = conn.recv(num_left) + if not data: + return + self.mask_buf += data + if len(self.mask_buf) < 4: + return + self.mask = bytearray(self.mask_buf) + del self.mask_buf + self.state = self.read_payload + self.pos = 0 + self.frame_starting = True + + def read_payload(self, conn): + bytes_left = self.payload_length - self.pos + if bytes_left > 0: + data = conn.recv(min(bytes_left, CHUNK_SIZE)) + if not data: + return + else: + data = b'' + unmasked = bytes(bytearray(ord(byte) ^ self.mask[(self.pos + i) & 3] for i, byte in enumerate(data))) + self.pos += len(data) + conn.ws_data_received(unmasked, self.opcode, self.frame_starting, self.pos >= self.payload_length, bool(self.fin)) + self.frame_starting = False +# }}} + +# Sending frames {{{ + +def create_frame(fin, opcode, payload, mask=None): + if isinstance(payload, type('')): + payload = payload.encode('utf-8') + l = len(payload) + opcode &= 0b1111 + b1 = opcode | (0b10000000 if fin else 0) + b2 = 0 if mask is None else 0b10000000 + if l < 126: + header = bytes(bytearray((b1, b2 | l))) + elif 126 <= l <= 65535: + header = bytes(bytearray((b1, b2 | 126))) + struct.pack(b'!H', l) + else: + header = bytes(bytearray((b1, b2 | 127))) + struct.pack(b'!Q', l) + if mask is not None: + header += mask + mask = bytearray(mask) + payload = bytearray(payload) + for i in xrange(len(payload)): + payload[i] ^= mask[i & 3] + payload = bytes(payload) + + return header + payload + + +class MessageWriter(object): + + def __init__(self, buf): + self.buf, self.data_type = buf, BINARY + if isinstance(buf, type('')): + self.buf, self.data_type = BytesIO(buf.encode('utf-8')), TEXT + elif isinstance(buf, bytes): + self.buf = BytesIO(buf) + buf = self.buf + try: + pos = buf.tell() + buf.seek(0, os.SEEK_END) + self.size = buf.tell() - pos + buf.seek(pos) + except Exception: + self.size = None + self.first_frame_created = self.exhausted = False + + def create_frame(self): + if self.exhausted: + return None + buf = self.buf + raw = buf.read(SEND_CHUNK_SIZE) + has_more = True if self.size is None else self.size > buf.tell() + fin = 0 if has_more and raw else 1 + opcode = 0 if self.first_frame_created else self.data_type + self.first_frame_created, self.exhausted = True, bool(fin) + return BytesIO(create_frame(fin, opcode, raw)) +# }}} + +conn_id = 0 + +class WebSocketConnection(HTTPConnection): + + in_websocket_mode = False + websocket_handler = None + + def __init__(self, *args, **kwargs): + global conn_id + HTTPConnection.__init__(self, *args, **kwargs) + self.sendq = Queue() + self.control_frames = [] + self.cf_lock = Lock() + self.sending = None + self.send_buf = None + self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict') + self.ws_close_received = self.ws_close_sent = False + conn_id += 1 + self.websocket_connection_id = conn_id + + def finalize_headers(self, inheaders): + upgrade = inheaders.get('Upgrade', None) + key = inheaders.get('Sec-WebSocket-Key', None) + conn = inheaders.get('Connection', None) + if key is None or upgrade.lower() != 'websocket' or conn != 'Upgrade': + return HTTPConnection.finalize_headers(self, inheaders) + ver = inheaders.get('Sec-WebSocket-Version', 'Unknown') + try: + ver_ok = int(ver) >= 13 + except Exception: + ver_ok = False + if not ver_ok: + return self.simple_response(httplib.BAD_REQUEST, 'Unsupported WebSocket protocol version: %s' % ver) + if self.method != 'GET': + return self.simple_response(httplib.BAD_REQUEST, 'Invalid WebSocket method: %s' % self.method) + + response = HANDSHAKE_STR % standard_b64encode(sha1(key + GUID_STR).digest()) + self.optimize_for_sending_packet() + self.set_state(WRITE, self.upgrade_connection_to_ws, BytesIO(response.encode('ascii')), inheaders) + + def upgrade_connection_to_ws(self, buf, inheaders, event): + if self.write(buf): + if self.websocket_handler is None: + self.websocket_handler = DummyHandler() + self.current_recv_frame, self.current_recv_opcode = ReadFrame(), None + self.in_websocket_mode = True + try: + self.websocket_handler.handle_websocket_upgrade(self.websocket_connection_id, weakref.ref(self), inheaders) + except Exception as err: + self.log.exception('Error in WebSockets upgrade handler:') + self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err)) + self.set_ws_state() + + def set_ws_state(self): + if self.ws_close_sent or self.ws_close_received: + if self.ws_close_sent: + self.set_state(READ, self.ws_duplex) + if self.ws_close_received: + self.ready = False + else: + self.set_state(WRITE, self.ws_duplex) + return + + if self.send_buf is not None or self.sending is not None: + self.set_state(RDWR, self.ws_duplex) + else: + try: + self.sending = self.sendq.get_nowait() + except Empty: + with self.cf_lock: + if self.control_frames: + self.set_state(RDWR, self.ws_duplex) + else: + self.set_state(READ, self.ws_duplex) + else: + self.set_state(RDWR, self.ws_duplex) + + def ws_duplex(self, event): + if event is READ: + self.ws_read() + elif event is WRITE: + self.ws_write() + self.set_ws_state() + + def ws_read(self): + self.current_recv_frame.state(self) + + def ws_data_received(self, data, opcode, frame_starting, frame_finished, is_final_frame_of_message): + if opcode in CONTROL_CODES: + return self.ws_control_frame(opcode, data) + + message_starting = self.current_recv_opcode is None + if message_starting: + self.current_recv_opcode = opcode + else: + if opcode != CONTINUATION: + # This is a new message + try: + self.handle_websocket_data('' if self.current_recv_opcode == TEXT else b'', False, True) + except Exception as err: + self.log.exception('Error in WebSockets data handler:') + self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err)) + self.current_recv_opcode = opcode + message_starting = True + message_finished = frame_finished and is_final_frame_of_message + if message_finished: + self.current_recv_frame, self.current_recv_opcode = ReadFrame(), None + if opcode == TEXT: + if message_starting: + self.frag_decoder.reset() + try: + data = self.frag_decoder.decode(data, final=message_finished) + except ValueError: + self.frag_decoder.reset() + return self.websocket_close(INCONSISTENT_DATA, 'Not valid UTF-8') + if message_finished: + self.frag_decoder.reset() + try: + self.handle_websocket_data(data, message_starting, message_finished) + except Exception as err: + self.log.exception('Error in WebSockets data handler:') + self.websocket_close(UNEXPECTED_ERROR, 'Unexpected error in handler: %r' % as_unicode(err)) + + def ws_control_frame(self, opcode, data): + if opcode in (PING, CLOSE): + rcode = PONG if opcode == PING else CLOSE + f = BytesIO(create_frame(1, rcode, data)) + f.is_close_frame = opcode == CLOSE + with self.cf_lock: + self.control_frames.append(f) + if opcode == CLOSE: + self.ws_close_received = True + self.set_ws_state() + + def websocket_close(self, code=NORMAL_CLOSE, reason=b''): + if isinstance(reason, type('')): + reason = reason.encode('utf-8') + reason = reason[:123] + if code is None and not reason: + f = BytesIO(create_frame(1, CLOSE, b'')) + else: + f = BytesIO(create_frame(1, CLOSE, struct.pack(b'!H', code) + reason)) + f.is_close_frame = True + with self.cf_lock: + self.control_frames.append(f) + self.set_ws_state() + + def ws_write(self): + if self.ws_close_sent: + return + if self.send_buf is not None: + if self.write(self.send_buf): + if getattr(self.send_buf, 'is_close_frame', False): + self.ws_close_sent = True + self.send_buf = None + else: + with self.cf_lock: + try: + self.send_buf = self.control_frames.pop() + except IndexError: + if self.sending is not None: + self.send_buf = self.sending.create_frame() + if self.send_buf is None: + self.sending = None + + def close(self): + if self.in_websocket_mode: + try: + self.websocket_handler.handle_websocket_close(self.websocket_connection_id) + except Exception: + self.log.exception('Error in WebSocket close handler') + # Try to write a close frame, just once + try: + if self.send_buf is None: + self.websocket_close(SHUTTING_DOWN, 'Shutting down') + self.write(self.control_frames.pop()) + except Exception: + pass + Connection.close(self) + else: + HTTPConnection.close(self) + + def send_websocket_message(self, buf, wakeup=True): + self.sendq.put(MessageWriter(buf)) + self.wait_for = RDWR + if wakeup: + self.wakeup() + + def handle_websocket_data(self, data, message_starting, message_finished): + self.websocket_handler.handle_websocket_data(data, message_starting, message_finished, self.websocket_connection_id) + +class DummyHandler(object): + + def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders): + conn = connection_ref() + conn.websocket_close(NORMAL_CLOSE, 'No WebSocket handler available') + + def handle_websocket_data(self, data, message_starting, message_finished, connection_id): + pass + + def handle_websocket_close(self, connection_id): + pass + +# Testing {{{ +class EchoClientHandler(object): + + def __init__(self, *args, **kwargs): + self.msg_buf = [] + self.ws_connections = {} + + def conn(self, cid): + ans = self.ws_connections.get(cid) + if ans is not None: + ans = ans() + return ans + + def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders): + self.ws_connections[connection_id] = connection_ref + + def handle_websocket_data(self, data, message_starting, message_finished, connection_id): + if message_starting: + self.msg_buf = [] + self.msg_buf.append(data) + if message_finished: + j = '' if isinstance(self.msg_buf[0], type('')) else b'' + msg = j.join(self.msg_buf) + self.msg_buf = [] + print('Received message from client:', reprlib.repr(msg)) + self.conn(connection_id).send_websocket_message(msg) + + def handle_websocket_close(self, connection_id): + self.ws_connections.pop(connection_id, None) + +class EchoServerHandler(object): + + def __init__(self, *args, **kwargs): + self.msg_buf = [] + self.ws_connections = {} + t = Thread(name='StdinReader', target=self.get_input) + t.start() + + def conn(self, cid): + ans = self.ws_connections.get(cid) + if ans is not None: + ans = ans() + return ans + + def handle_websocket_upgrade(self, connection_id, connection_ref, inheaders): + self.ws_connections[connection_id] = connection_ref + + def handle_websocket_data(self, data, message_starting, message_finished, connection_id): + if message_starting: + self.msg_buf = [] + self.msg_buf.append(data) + if message_finished: + j = '' if isinstance(self.msg_buf[0], type('')) else b'' + msg = j.join(self.msg_buf) + self.msg_buf = [] + print('Received message from client:', msg) + + def get_input(self): + time.sleep(0.5) + try: + while True: + try: + raw = raw_input('Enter some text: ') + if not raw: + break + raw = raw.decode('utf-8') + for conn in self.ws_connections.itervalues(): + conn = conn() + if conn is not None: + conn.send_websocket_message(raw) + except (EOFError, KeyboardInterrupt): + break + finally: + for conn in self.ws_connections.itervalues(): + conn = conn() + if conn is not None: + conn.close() + print('\nUse Ctrl+C to exit server loop') + + def handle_websocket_close(self, connection_id): + self.ws_connections.pop(connection_id, None) + +if __name__ == '__main__': + s = ServerLoop(create_http_handler(websocket_handler=EchoClientHandler())) + with HandleInterrupt(s.wakeup): + s.serve_forever() +# }}}