More WS speedups by using the buffer protocol to avoid mallocs

This commit is contained in:
Kovid Goyal 2015-10-28 11:34:33 +05:30
parent 689bec8c46
commit 6925b23754
4 changed files with 150 additions and 119 deletions

View File

@ -213,6 +213,29 @@ class Connection(object): # {{{
return b'' return b''
raise raise
def recv_into(self, buf, amt=0):
    """Read up to ``amt`` bytes directly into the writable buffer ``buf``.

    Bytes already queued in ``self.read_buffer`` are served first; only when
    it holds no data is the socket itself read.

    :param buf: Writable buffer (e.g. a bytearray or memoryview) to fill,
        starting at index 0.
    :param amt: Maximum number of bytes to read. ``0`` means ``len(buf)``.
        Values larger than ``len(buf)`` are clamped to ``len(buf)`` so that
        no more bytes are consumed than the buffer can hold.
    :return: The number of bytes written into ``buf``. ``0`` means that
        nothing could be read: no space in ``buf``, a would-block/EINTR
        condition, or a closed connection (``self.ready`` is then set to
        ``False``).
    :raises socket.error: For socket errors other than the would-block,
        EINTR and socket-closed families handled here.
    """
    capacity = len(buf)
    if capacity == 0:
        # A zero byte read would return 0, which is indistinguishable from
        # the "peer closed the connection" signal handled below.
        return 0
    amt = min(amt or capacity, capacity)
    if self.read_buffer.has_data:
        data = self.read_buffer.read(amt)
        buf[0:len(data)] = data
        return len(data)
    try:
        bytes_read = self.socket.recv_into(buf, amt)
        self.last_activity = monotonic()
        if bytes_read == 0:
            # a closed connection is indicated by signaling
            # a read condition, and having recv() return 0.
            self.ready = False
            return 0
        return bytes_read
    except socket.error as e:
        if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
            return 0
        if e.errno in socket_errors_socket_closed:
            self.ready = False
            return 0
        raise
def fill_read_buffer(self): def fill_read_buffer(self):
try: try:
num = self.read_buffer.recv_from(self.socket) num = self.read_buffer.recv_from(self.socket)

View File

@ -37,7 +37,7 @@ class WSClient(object):
self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii')) self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii'))
self.read_buf = deque() self.read_buf = deque()
self.read_upgrade_response() self.read_upgrade_response()
self.mask = os.urandom(4) self.mask = memoryview(os.urandom(4))
self.frames = [] self.frames = []
def read_upgrade_response(self): def read_upgrade_response(self):

View File

@ -5,12 +5,12 @@
from __future__ import (unicode_literals, division, absolute_import, from __future__ import (unicode_literals, division, absolute_import,
print_function) print_function)
import httplib, struct, os, weakref, socket import httplib, os, weakref, socket
from base64 import standard_b64encode from base64 import standard_b64encode
from collections import deque from collections import deque
from functools import partial
from hashlib import sha1 from hashlib import sha1
from Queue import Queue, Empty from Queue import Queue, Empty
from struct import unpack_from, pack, error as struct_error
from threading import Lock from threading import Lock
from calibre import as_unicode from calibre import as_unicode
@ -58,28 +58,31 @@ RESERVED_CLOSE_CODES = (1004,1005,1006,)
class ReadFrame(object): # {{{ class ReadFrame(object): # {{{
def __init__(self):
    """Allocate the reusable receive buffers and enter the initial state."""
    # Payload bytes are received into this buffer, CHUNK_SIZE bytes at most
    # per read.
    self.rbuf = bytearray(CHUNK_SIZE)
    # Frame header bytes are received here; the masking key is later
    # sliced out of this same buffer.
    self.header_buf = bytearray(14)
    self.reset()
def reset(self):
    """Prepare to parse a new frame, starting with its header."""
    self.state = self.read_header
    # The first read targets only the first six bytes of the header buffer.
    self.header_view = memoryview(self.header_buf)[:6]
def __call__(self, conn): def __call__(self, conn):
return self.state(conn) return self.state(conn)
def read_header0(self, conn): def read_header(self, conn):
data = conn.recv(1) num_bytes = conn.recv_into(self.header_view)
if not data: if num_bytes == 0:
return return
b = ord(data) read_bytes = 6 - len(self.header_view) + num_bytes
self.fin = bool(b & 0b10000000) if read_bytes > 2:
if b & 0b01110000: b1, b2 = self.header_buf[0], self.header_buf[1]
self.fin = bool(b1 & 0b10000000)
if b1 & 0b01110000:
conn.log.error('RSV bits set in frame from client') conn.log.error('RSV bits set in frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set') conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set')
return return
self.opcode = b & 0b1111 self.opcode = b1 & 0b1111
self.state = self.read_header1
self.is_control = self.opcode in CONTROL_CODES self.is_control = self.opcode in CONTROL_CODES
if self.opcode not in ALL_CODES: if self.opcode not in ALL_CODES:
conn.log.error('Unknown OPCODE from client: %r' % self.opcode) conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
@ -90,82 +93,83 @@ class ReadFrame(object): # {{{
conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame') conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame')
return return
def read_header1(self, conn): mask = b2 & 0b10000000
data = conn.recv(1) if not mask:
if not data:
return
b = ord(data)
self.mask = b & 0b10000000
if not self.mask:
conn.log.error('Unmasked packet from client') conn.log.error('Unmasked packet from client')
conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed')
self.reset() self.reset()
return return
self.payload_length = b & 0b01111111 self.payload_length = l = b2 & 0b01111111
if self.is_control and self.payload_length > 125: if self.is_control and l > 125:
conn.log.error('Too large control frame from client') conn.log.error('Too large control frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large') conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large')
self.reset() self.reset()
return return
self.mask_buf = b'' header_len = 6 + (0 if l < 126 else 2 if l == 126 else 8)
if self.payload_length == 126: if header_len <= read_bytes:
self.plbuf = b'' self.process_header(conn)
self.state = partial(self.read_payload_length, 2)
elif self.payload_length == 127:
self.plbuf = b''
self.state = partial(self.read_payload_length, 8)
else: else:
self.state = self.read_masking_key self.header_view = memoryview(self.header_buf)[read_bytes:header_len]
self.state = self.finish_reading_header
else:
self.header_view = self.header_view[num_bytes:]
def finish_reading_header(self, conn):
    """Receive the header bytes still missing after the first header read.

    Calls ``process_header`` once ``header_view`` has been filled; until
    then ``header_view`` is narrowed to the bytes not yet received.
    """
    pending = self.header_view
    received = conn.recv_into(pending)
    if received == 0:
        return
    if received < len(pending):
        # Header still incomplete: keep only the unfilled tail of the view.
        self.header_view = pending[received:]
        return
    self.process_header(conn)
def process_header(self, conn):
    """Decode a completely received header and select the payload handler.

    Resolves the final payload length (reading the extended 16 or 64 bit
    length stored at offset 2 when the short length is 126 or 127), points
    ``self.mask`` at the four masking key bytes inside ``header_buf`` and
    then either reports an empty frame immediately or switches to
    ``read_packet`` / ``read_payload``.
    """
    header = self.header_buf
    short_length = self.payload_length
    if short_length < 126:
        mask_start = 2
    else:
        if short_length == 126:
            fmt, mask_start = b'!H', 4
        else:
            fmt, mask_start = b'!Q', 10
        self.payload_length, = unpack_from(fmt, header, 2)
    self.mask = memoryview(header)[mask_start:mask_start + 4]
    self.frame_starting = True
    self.bytes_received = 0
    total = self.payload_length
    if total > CHUNK_SIZE:
        # Does not fit into rbuf in one piece: receive it chunk by chunk.
        self.rview = memoryview(self.rbuf)
        self.state = self.read_payload
    elif total == 0:
        conn.ws_data_received(b'', self.opcode, True, True, self.fin)
        self.reset()
    else:
        # The whole payload fits into rbuf: collect it as a single packet.
        self.rview = memoryview(self.rbuf)[:total]
        self.state = self.read_packet
def read_packet(self, conn):
    """Receive a payload that fits entirely inside ``rbuf``.

    Nothing is reported to ``conn`` until ``rview`` has been filled; then
    the payload is unmasked in place and delivered as one complete frame.
    """
    remaining = self.rview
    received = conn.recv_into(remaining)
    if received == 0:
        return
    if received < len(remaining):
        # More bytes outstanding: advance the view past what was received.
        self.rview = remaining[received:]
        return
    payload = memoryview(self.rbuf)[:self.payload_length]
    fast_mask(payload, self.mask)
    conn.ws_data_received(payload.tobytes(), self.opcode, True, True, self.fin)
    self.reset()
def read_payload(self, conn):
    """Receive and deliver the next chunk of a payload larger than ``rbuf``.

    Each chunk is unmasked in place, using the number of payload bytes
    already received as the offset into the masking key, and handed to
    ``conn.ws_data_received`` straight away.
    """
    outstanding = self.payload_length - self.bytes_received
    received = conn.recv_into(self.rview, min(len(self.rview), outstanding))
    if received == 0:
        return
    already_received = self.bytes_received
    chunk = memoryview(self.rbuf)[:received]
    fast_mask(chunk, self.mask, already_received)
    self.bytes_received = already_received + received
    is_last = self.bytes_received >= self.payload_length
    conn.ws_data_received(chunk.tobytes(), self.opcode, self.frame_starting, is_last, self.fin)
    self.frame_starting = False
    if is_last:
        self.reset()
# }}} # }}}
# Sending frames {{{ # Sending frames {{{
@ -174,20 +178,26 @@ def create_frame(fin, opcode, payload, mask=None, rsv=0):
if isinstance(payload, type('')): if isinstance(payload, type('')):
payload = payload.encode('utf-8') payload = payload.encode('utf-8')
l = len(payload) l = len(payload)
opcode &= 0b1111 header_len = 2 + (0 if l < 126 else 2 if 126 <= l <= 65535 else 8) + (0 if mask is None else 4)
b1 = opcode | (0b10000000 if fin else 0) | (rsv & 0b01110000) frame = bytearray(header_len + l)
b2 = 0 if mask is None else 0b10000000 if l > 0:
frame[-l:] = payload
frame[0] = (opcode & 0b1111) | (0b10000000 if fin else 0) | (rsv & 0b01110000)
if l < 126: if l < 126:
header = bytes(bytearray((b1, b2 | l))) frame[1] = l
elif 126 <= l <= 65535: elif 126 <= l <= 65535:
header = bytes(bytearray((b1, b2 | 126))) + struct.pack(b'!H', l) frame[2:4] = pack(b'!H', l)
frame[1] = 126
else: else:
header = bytes(bytearray((b1, b2 | 127))) + struct.pack(b'!Q', l) frame[2:10] = pack(b'!Q', l)
frame[1] = 127
if mask is not None: if mask is not None:
header += mask frame[1] |= 0b10000000
payload = fast_mask(payload, mask) frame[header_len-4:header_len] = mask
if l > 0:
fast_mask(memoryview(frame)[-l:], mask)
return header + payload return memoryview(frame)
class MessageWriter(object): class MessageWriter(object):
@ -379,20 +389,20 @@ class WebSocketConnection(HTTPConnection):
self.stop_reading = True self.stop_reading = True
if data: if data:
try: try:
close_code = struct.unpack_from(b'!H', data)[0] close_code = unpack_from(b'!H', data)[0]
except struct.error: except struct_error:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes' data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes'
else: else:
try: try:
utf8_decode(data[2:]) utf8_decode(data[2:])
except ValueError: except ValueError:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8' data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8'
else: else:
if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000): if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000):
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close code reserved' data = pack(b'!H', PROTOCOL_ERROR) + b'close code reserved'
else: else:
close_code = NORMAL_CLOSE close_code = NORMAL_CLOSE
data = struct.pack(b'!H', close_code) data = pack(b'!H', close_code)
f = ReadOnlyFileBuffer(create_frame(1, rcode, data)) f = ReadOnlyFileBuffer(create_frame(1, rcode, data))
f.is_close_frame = opcode == CLOSE f.is_close_frame = opcode == CLOSE
with self.cf_lock: with self.cf_lock:
@ -412,7 +422,7 @@ class WebSocketConnection(HTTPConnection):
if code is None and not reason: if code is None and not reason:
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, b'')) f = ReadOnlyFileBuffer(create_frame(1, CLOSE, b''))
else: else:
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, struct.pack(b'!H', code) + reason)) f = ReadOnlyFileBuffer(create_frame(1, CLOSE, pack(b'!H', code) + reason))
f.is_close_frame = True f.is_close_frame = True
with self.cf_lock: with self.cf_lock:
self.control_frames.append(f) self.control_frames.append(f)

View File

@ -231,19 +231,17 @@ speedup_create_texture(PyObject *self, PyObject *args, PyObject *kw) {
/*
 * websocket_mask(data, mask, offset=0)
 *
 * XOR every byte of the writable buffer `data` in place with the masking key
 * `mask`, using mask byte (i + offset) & 3 for data byte i. Returns None.
 *
 * Raises ValueError if `mask` is shorter than four bytes, and whatever
 * PyObject_GetBuffer() raises if `data` is not a writable buffer or `mask`
 * does not support the buffer protocol.
 */
static PyObject*
speedup_websocket_mask(PyObject *self, PyObject *args) {
    PyObject *data = NULL, *mask = NULL, *ans = NULL;
    Py_buffer data_buf = {0}, mask_buf = {0};
    Py_ssize_t offset = 0, i = 0;
    char *dbuf = NULL, *mbuf = NULL;

    if (!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset)) return NULL;
    if (PyObject_GetBuffer(data, &data_buf, PyBUF_SIMPLE|PyBUF_WRITABLE) != 0) return NULL;
    /* From here on data_buf is held, so every exit must go through `end`. */
    if (PyObject_GetBuffer(mask, &mask_buf, PyBUF_SIMPLE) != 0) goto end;
    if (mask_buf.len < 4) {
        /* The loop below indexes mbuf[0..3]; refuse shorter masks. */
        PyErr_SetString(PyExc_ValueError, "mask must be at least 4 bytes long");
        goto end;
    }

    dbuf = (char*)data_buf.buf; mbuf = (char*)mask_buf.buf;
    for (i = 0; i < data_buf.len; i++)
        dbuf[i] ^= mbuf[(i + offset) & 3];

    ans = Py_None;
    Py_INCREF(ans);

end:
    /* Every successful PyObject_GetBuffer() must be paired with a release,
     * otherwise the exporting object is leaked and stays locked. */
    if (mask_buf.obj != NULL) PyBuffer_Release(&mask_buf);
    if (data_buf.obj != NULL) PyBuffer_Release(&data_buf);
    return ans;
}
#define UTF8_ACCEPT 0 #define UTF8_ACCEPT 0