diff --git a/src/calibre/srv/loop.py b/src/calibre/srv/loop.py index c37a03a852..ea6a99268a 100644 --- a/src/calibre/srv/loop.py +++ b/src/calibre/srv/loop.py @@ -213,6 +213,29 @@ class Connection(object): # {{{ return b'' raise + def recv_into(self, buf, amt=0): + amt = amt or len(buf) + if self.read_buffer.has_data: + data = self.read_buffer.read(amt) + buf[0:len(data)] = data + return len(data) + try: + bytes_read = self.socket.recv_into(buf, amt) + self.last_activity = monotonic() + if bytes_read == 0: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.ready = False + return 0 + return bytes_read + except socket.error as e: + if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr: + return 0 + if e.errno in socket_errors_socket_closed: + self.ready = False + return 0 + raise + def fill_read_buffer(self): try: num = self.read_buffer.recv_from(self.socket) diff --git a/src/calibre/srv/tests/web_sockets.py b/src/calibre/srv/tests/web_sockets.py index 2a265ce57c..bf7a775389 100644 --- a/src/calibre/srv/tests/web_sockets.py +++ b/src/calibre/srv/tests/web_sockets.py @@ -37,7 +37,7 @@ class WSClient(object): self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii')) self.read_buf = deque() self.read_upgrade_response() - self.mask = os.urandom(4) + self.mask = memoryview(os.urandom(4)) self.frames = [] def read_upgrade_response(self): diff --git a/src/calibre/srv/web_socket.py b/src/calibre/srv/web_socket.py index eb78c01f8f..e0ee4e2d7f 100644 --- a/src/calibre/srv/web_socket.py +++ b/src/calibre/srv/web_socket.py @@ -5,12 +5,12 @@ from __future__ import (unicode_literals, division, absolute_import, print_function) -import httplib, struct, os, weakref, socket +import httplib, os, weakref, socket from base64 import standard_b64encode from collections import deque -from functools import partial from hashlib import sha1 from Queue import Queue, Empty +from struct import unpack_from, pack, error as struct_error from threading import Lock from calibre import as_unicode @@ -58,114 +58,118 @@ RESERVED_CLOSE_CODES = (1004,1005,1006,) class ReadFrame(object): # {{{ def __init__(self): + self.header_buf = bytearray(14) + self.rbuf = bytearray(CHUNK_SIZE) self.reset() def reset(self): - self.state = self.read_header0 - self.control_buf = [] + self.header_view = memoryview(self.header_buf)[:6] + self.state = self.read_header def __call__(self, conn): return self.state(conn) - def read_header0(self, conn): - data = conn.recv(1) - if not data: - return - b = ord(data) - self.fin = bool(b & 0b10000000) - if b & 0b01110000: - conn.log.error('RSV bits set in frame from client') - conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set') + def read_header(self, conn): + num_bytes = conn.recv_into(self.header_view) + if num_bytes == 0: return + read_bytes = 6 - len(self.header_view) + num_bytes + if read_bytes > 2: + b1, b2 = self.header_buf[0], self.header_buf[1] + self.fin = bool(b1 & 0b10000000) + if b1 & 0b01110000: + conn.log.error('RSV bits set in frame from client') + conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set') + return - self.opcode = b & 0b1111 - self.state = self.read_header1 - self.is_control = self.opcode in CONTROL_CODES - if self.opcode not in ALL_CODES: - conn.log.error('Unknown OPCODE from client: %r' % self.opcode) - conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) - return - if not self.fin and self.is_control: - conn.log.error('Fragmented control frame from client') - conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame') - return + self.opcode = b1 & 0b1111 + self.is_control = self.opcode in CONTROL_CODES + if self.opcode not in ALL_CODES: + conn.log.error('Unknown OPCODE from client: %r' % self.opcode) + conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) + return + if not self.fin and self.is_control: + conn.log.error('Fragmented control frame from client') + conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame') + return - def read_header1(self, conn): - data = conn.recv(1) - if not data: - return - b = ord(data) - self.mask = b & 0b10000000 - if not self.mask: - conn.log.error('Unmasked packet from client') - conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') - self.reset() - return - self.payload_length = b & 0b01111111 - if self.is_control and self.payload_length > 125: - conn.log.error('Too large control frame from client') - conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large') - self.reset() - return - self.mask_buf = b'' - if self.payload_length == 126: - self.plbuf = b'' - self.state = partial(self.read_payload_length, 2) - elif self.payload_length == 127: - self.plbuf = b'' - self.state = partial(self.read_payload_length, 8) + mask = b2 & 0b10000000 + if not mask: + conn.log.error('Unmasked packet from client') + conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') + self.reset() + return + self.payload_length = l = b2 & 0b01111111 + if self.is_control and l > 125: + conn.log.error('Too large control frame from client') + conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large') + self.reset() + return + header_len = 6 + (0 if l < 126 else 2 if l == 126 else 8) + if header_len <= read_bytes: + self.process_header(conn) + else: + self.header_view = memoryview(self.header_buf)[read_bytes:header_len] + self.state = self.finish_reading_header else: - self.state = self.read_masking_key + self.header_view = self.header_view[num_bytes:] - def read_payload_length(self, size_in_bytes, conn): - num_left = size_in_bytes - len(self.plbuf) - data = conn.recv(num_left) - if not data: + def finish_reading_header(self, conn): + num_bytes = conn.recv_into(self.header_view) + if num_bytes == 0: return - self.plbuf += data - if len(self.plbuf) < size_in_bytes: - return - fmt = b'!H' if size_in_bytes == 2 else b'!Q' - self.payload_length = struct.unpack(fmt, self.plbuf)[0] - del self.plbuf - self.state = self.read_masking_key + if num_bytes >= len(self.header_view): + self.process_header(conn) + else: + self.header_view = self.header_view[num_bytes:] - def read_masking_key(self, conn): - num_left = 4 - len(self.mask_buf) - data = conn.recv(num_left) - if not data: - return - self.mask_buf += data - if len(self.mask_buf) < 4: - return - self.state = self.read_payload - self.pos = 0 + def process_header(self, conn): + if self.payload_length < 126: + self.mask = memoryview(self.header_buf)[2:6] + elif self.payload_length == 126: + self.payload_length, = unpack_from(b'!H', self.header_buf, 2) + self.mask = memoryview(self.header_buf)[4:8] + else: + self.payload_length, = unpack_from(b'!Q', self.header_buf, 2) + self.mask = memoryview(self.header_buf)[10:14] self.frame_starting = True - if self.payload_length == 0: - conn.ws_data_received(b'', self.opcode, True, True, self.fin) + self.bytes_received = 0 + if self.payload_length <= CHUNK_SIZE: + if self.payload_length == 0: + conn.ws_data_received(b'', self.opcode, True, True, self.fin) + self.reset() + else: + self.rview = memoryview(self.rbuf)[:self.payload_length] + self.state = self.read_packet + else: + self.rview = memoryview(self.rbuf) + self.state = self.read_payload + + def read_packet(self, conn): + num_bytes = conn.recv_into(self.rview) + if num_bytes == 0: + return + if num_bytes >= len(self.rview): + data = memoryview(self.rbuf)[:self.payload_length] + fast_mask(data, self.mask) + conn.ws_data_received(data.tobytes(), self.opcode, True, True, self.fin) self.reset() + else: + self.rview = self.rview[num_bytes:] def read_payload(self, conn): - bytes_left = self.payload_length - self.pos - if bytes_left > 0: - data = conn.recv(min(bytes_left, CHUNK_SIZE)) - if not data: - return - data = fast_mask(data, self.mask_buf, self.pos) - else: - data = b'' - self.pos += len(data) - frame_finished = self.pos >= self.payload_length - if self.is_control: - self.control_buf.append(data) - if frame_finished: - data = b''.join(self.control_buf) - else: - return - conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin) + num_bytes = conn.recv_into(self.rview, min(len(self.rview), self.payload_length - self.bytes_received)) + if num_bytes == 0: + return + data = memoryview(self.rbuf)[:num_bytes] + fast_mask(data, self.mask, self.bytes_received) + self.bytes_received += num_bytes + frame_finished = self.bytes_received >= self.payload_length + conn.ws_data_received(data.tobytes(), self.opcode, self.frame_starting, frame_finished, self.fin) self.frame_starting = False if frame_finished: self.reset() + # }}} # Sending frames {{{ @@ -174,20 +178,26 @@ def create_frame(fin, opcode, payload, mask=None, rsv=0): if isinstance(payload, type('')): payload = payload.encode('utf-8') l = len(payload) - opcode &= 0b1111 - b1 = opcode | (0b10000000 if fin else 0) | (rsv & 0b01110000) - b2 = 0 if mask is None else 0b10000000 + header_len = 2 + (0 if l < 126 else 2 if 126 <= l <= 65535 else 8) + (0 if mask is None else 4) + frame = bytearray(header_len + l) + if l > 0: + frame[-l:] = payload + frame[0] = (opcode & 0b1111) | (0b10000000 if fin else 0) | (rsv & 0b01110000) if l < 126: - header = bytes(bytearray((b1, b2 | l))) + frame[1] = l elif 126 <= l <= 65535: - header = bytes(bytearray((b1, b2 | 126))) + struct.pack(b'!H', l) + frame[2:4] = pack(b'!H', l) + frame[1] = 126 else: - header = bytes(bytearray((b1, b2 | 127))) + struct.pack(b'!Q', l) + frame[2:10] = pack(b'!Q', l) + frame[1] = 127 if mask is not None: - header += mask - payload = fast_mask(payload, mask) + frame[1] |= 0b10000000 + frame[header_len-4:header_len] = mask + if l > 0: + fast_mask(memoryview(frame)[-l:], mask) - return header + payload + return memoryview(frame) class MessageWriter(object): @@ -379,20 +389,20 @@ class WebSocketConnection(HTTPConnection): self.stop_reading = True if data: try: - close_code = struct.unpack_from(b'!H', data)[0] - except struct.error: - data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes' + close_code = unpack_from(b'!H', data)[0] + except struct_error: + data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes' else: try: utf8_decode(data[2:]) except ValueError: - data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8' + data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8' else: if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000): - data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close code reserved' + data = pack(b'!H', PROTOCOL_ERROR) + b'close code reserved' else: close_code = NORMAL_CLOSE - data = struct.pack(b'!H', close_code) + data = pack(b'!H', close_code) f = ReadOnlyFileBuffer(create_frame(1, rcode, data)) f.is_close_frame = opcode == CLOSE with self.cf_lock: @@ -412,7 +422,7 @@ class WebSocketConnection(HTTPConnection): if code is None and not reason: f = ReadOnlyFileBuffer(create_frame(1, CLOSE, b'')) else: - f = ReadOnlyFileBuffer(create_frame(1, CLOSE, struct.pack(b'!H', code) + reason)) + f = ReadOnlyFileBuffer(create_frame(1, CLOSE, pack(b'!H', code) + reason)) f.is_close_frame = True with self.cf_lock: self.control_frames.append(f) diff --git a/src/calibre/utils/speedup.c b/src/calibre/utils/speedup.c index 9cad526e8d..e63bcff6f5 100644 --- a/src/calibre/utils/speedup.c +++ b/src/calibre/utils/speedup.c @@ -231,19 +231,17 @@ speedup_create_texture(PyObject *self, PyObject *args, PyObject *kw) { static PyObject* speedup_websocket_mask(PyObject *self, PyObject *args) { - PyObject *data = NULL, *mask = NULL, *ans = NULL; - Py_ssize_t offset_ = 0; - size_t offset = 0, i = 0; - char *data_buf = NULL, *mask_buf = NULL, *ans_buf = NULL; - if(!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset_)) return NULL; - offset = (size_t)offset_; - ans = PyBytes_FromStringAndSize(NULL, PyBytes_GET_SIZE(data)); - if (ans != NULL) { - data_buf = PyBytes_AS_STRING(data); mask_buf = PyBytes_AS_STRING(mask); ans_buf = PyBytes_AS_STRING(ans); - for(i = 0; i < (size_t)PyBytes_GET_SIZE(ans); i++) - ans_buf[i] = data_buf[i] ^ mask_buf[(i + offset) & 3]; - } - return ans; + PyObject *data = NULL, *mask = NULL; + Py_buffer data_buf = {0}, mask_buf = {0}; + Py_ssize_t offset = 0, i = 0; + char *dbuf = NULL, *mbuf = NULL; + if(!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset)) return NULL; + if (PyObject_GetBuffer(data, &data_buf, PyBUF_SIMPLE|PyBUF_WRITABLE) != 0) return NULL; + if (PyObject_GetBuffer(mask, &mask_buf, PyBUF_SIMPLE) != 0) return NULL; + dbuf = (char*)data_buf.buf; mbuf = (char*)mask_buf.buf; + for(i = 0; i < data_buf.len; i++) + dbuf[i] ^= mbuf[(i + offset) & 3]; + Py_RETURN_NONE; } #define UTF8_ACCEPT 0