More WS speedups by using the buffer protocol to avoid mallocs

This commit is contained in:
Kovid Goyal 2015-10-28 11:34:33 +05:30
parent 689bec8c46
commit 6925b23754
4 changed files with 150 additions and 119 deletions

View File

@ -213,6 +213,29 @@ class Connection(object): # {{{
return b'' return b''
raise raise
def recv_into(self, buf, amt=0):
amt = amt or len(buf)
if self.read_buffer.has_data:
data = self.read_buffer.read(amt)
buf[0:len(data)] = data
return len(data)
try:
bytes_read = self.socket.recv_into(buf, amt)
self.last_activity = monotonic()
if bytes_read == 0:
# a closed connection is indicated by signaling
# a read condition, and having recv() return 0.
self.ready = False
return 0
return bytes_read
except socket.error as e:
if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
return 0
if e.errno in socket_errors_socket_closed:
self.ready = False
return 0
raise
def fill_read_buffer(self): def fill_read_buffer(self):
try: try:
num = self.read_buffer.recv_from(self.socket) num = self.read_buffer.recv_from(self.socket)

View File

@ -37,7 +37,7 @@ class WSClient(object):
self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii')) self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii'))
self.read_buf = deque() self.read_buf = deque()
self.read_upgrade_response() self.read_upgrade_response()
self.mask = os.urandom(4) self.mask = memoryview(os.urandom(4))
self.frames = [] self.frames = []
def read_upgrade_response(self): def read_upgrade_response(self):

View File

@ -5,12 +5,12 @@
from __future__ import (unicode_literals, division, absolute_import, from __future__ import (unicode_literals, division, absolute_import,
print_function) print_function)
import httplib, struct, os, weakref, socket import httplib, os, weakref, socket
from base64 import standard_b64encode from base64 import standard_b64encode
from collections import deque from collections import deque
from functools import partial
from hashlib import sha1 from hashlib import sha1
from Queue import Queue, Empty from Queue import Queue, Empty
from struct import unpack_from, pack, error as struct_error
from threading import Lock from threading import Lock
from calibre import as_unicode from calibre import as_unicode
@ -58,114 +58,118 @@ RESERVED_CLOSE_CODES = (1004,1005,1006,)
class ReadFrame(object): # {{{ class ReadFrame(object): # {{{
def __init__(self): def __init__(self):
self.header_buf = bytearray(14)
self.rbuf = bytearray(CHUNK_SIZE)
self.reset() self.reset()
def reset(self): def reset(self):
self.state = self.read_header0 self.header_view = memoryview(self.header_buf)[:6]
self.control_buf = [] self.state = self.read_header
def __call__(self, conn): def __call__(self, conn):
return self.state(conn) return self.state(conn)
def read_header0(self, conn): def read_header(self, conn):
data = conn.recv(1) num_bytes = conn.recv_into(self.header_view)
if not data: if num_bytes == 0:
return
b = ord(data)
self.fin = bool(b & 0b10000000)
if b & 0b01110000:
conn.log.error('RSV bits set in frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set')
return return
read_bytes = 6 - len(self.header_view) + num_bytes
if read_bytes > 2:
b1, b2 = self.header_buf[0], self.header_buf[1]
self.fin = bool(b1 & 0b10000000)
if b1 & 0b01110000:
conn.log.error('RSV bits set in frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set')
return
self.opcode = b & 0b1111 self.opcode = b1 & 0b1111
self.state = self.read_header1 self.is_control = self.opcode in CONTROL_CODES
self.is_control = self.opcode in CONTROL_CODES if self.opcode not in ALL_CODES:
if self.opcode not in ALL_CODES: conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
conn.log.error('Unknown OPCODE from client: %r' % self.opcode) conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode)
conn.websocket_close(PROTOCOL_ERROR, 'Unknown OPCODE: %r' % self.opcode) return
return if not self.fin and self.is_control:
if not self.fin and self.is_control: conn.log.error('Fragmented control frame from client')
conn.log.error('Fragmented control frame from client') conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame')
conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame') return
return
def read_header1(self, conn): mask = b2 & 0b10000000
data = conn.recv(1) if not mask:
if not data: conn.log.error('Unmasked packet from client')
return conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed')
b = ord(data) self.reset()
self.mask = b & 0b10000000 return
if not self.mask: self.payload_length = l = b2 & 0b01111111
conn.log.error('Unmasked packet from client') if self.is_control and l > 125:
conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed') conn.log.error('Too large control frame from client')
self.reset() conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large')
return self.reset()
self.payload_length = b & 0b01111111 return
if self.is_control and self.payload_length > 125: header_len = 6 + (0 if l < 126 else 2 if l == 126 else 8)
conn.log.error('Too large control frame from client') if header_len <= read_bytes:
conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large') self.process_header(conn)
self.reset() else:
return self.header_view = memoryview(self.header_buf)[read_bytes:header_len]
self.mask_buf = b'' self.state = self.finish_reading_header
if self.payload_length == 126:
self.plbuf = b''
self.state = partial(self.read_payload_length, 2)
elif self.payload_length == 127:
self.plbuf = b''
self.state = partial(self.read_payload_length, 8)
else: else:
self.state = self.read_masking_key self.header_view = self.header_view[num_bytes:]
def read_payload_length(self, size_in_bytes, conn): def finish_reading_header(self, conn):
num_left = size_in_bytes - len(self.plbuf) num_bytes = conn.recv_into(self.header_view)
data = conn.recv(num_left) if num_bytes == 0:
if not data:
return return
self.plbuf += data if num_bytes >= len(self.header_view):
if len(self.plbuf) < size_in_bytes: self.process_header(conn)
return else:
fmt = b'!H' if size_in_bytes == 2 else b'!Q' self.header_view = self.header_view[num_bytes:]
self.payload_length = struct.unpack(fmt, self.plbuf)[0]
del self.plbuf
self.state = self.read_masking_key
def read_masking_key(self, conn): def process_header(self, conn):
num_left = 4 - len(self.mask_buf) if self.payload_length < 126:
data = conn.recv(num_left) self.mask = memoryview(self.header_buf)[2:6]
if not data: elif self.payload_length == 126:
return self.payload_length, = unpack_from(b'!H', self.header_buf, 2)
self.mask_buf += data self.mask = memoryview(self.header_buf)[4:8]
if len(self.mask_buf) < 4: else:
return self.payload_length, = unpack_from(b'!Q', self.header_buf, 2)
self.state = self.read_payload self.mask = memoryview(self.header_buf)[10:14]
self.pos = 0
self.frame_starting = True self.frame_starting = True
if self.payload_length == 0: self.bytes_received = 0
conn.ws_data_received(b'', self.opcode, True, True, self.fin) if self.payload_length <= CHUNK_SIZE:
if self.payload_length == 0:
conn.ws_data_received(b'', self.opcode, True, True, self.fin)
self.reset()
else:
self.rview = memoryview(self.rbuf)[:self.payload_length]
self.state = self.read_packet
else:
self.rview = memoryview(self.rbuf)
self.state = self.read_payload
def read_packet(self, conn):
num_bytes = conn.recv_into(self.rview)
if num_bytes == 0:
return
if num_bytes >= len(self.rview):
data = memoryview(self.rbuf)[:self.payload_length]
fast_mask(data, self.mask)
conn.ws_data_received(data.tobytes(), self.opcode, True, True, self.fin)
self.reset() self.reset()
else:
self.rview = self.rview[num_bytes:]
def read_payload(self, conn): def read_payload(self, conn):
bytes_left = self.payload_length - self.pos num_bytes = conn.recv_into(self.rview, min(len(self.rview), self.payload_length - self.bytes_received))
if bytes_left > 0: if num_bytes == 0:
data = conn.recv(min(bytes_left, CHUNK_SIZE)) return
if not data: data = memoryview(self.rbuf)[:num_bytes]
return fast_mask(data, self.mask, self.bytes_received)
data = fast_mask(data, self.mask_buf, self.pos) self.bytes_received += num_bytes
else: frame_finished = self.bytes_received >= self.payload_length
data = b'' conn.ws_data_received(data.tobytes(), self.opcode, self.frame_starting, frame_finished, self.fin)
self.pos += len(data)
frame_finished = self.pos >= self.payload_length
if self.is_control:
self.control_buf.append(data)
if frame_finished:
data = b''.join(self.control_buf)
else:
return
conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin)
self.frame_starting = False self.frame_starting = False
if frame_finished: if frame_finished:
self.reset() self.reset()
# }}} # }}}
# Sending frames {{{ # Sending frames {{{
@ -174,20 +178,26 @@ def create_frame(fin, opcode, payload, mask=None, rsv=0):
if isinstance(payload, type('')): if isinstance(payload, type('')):
payload = payload.encode('utf-8') payload = payload.encode('utf-8')
l = len(payload) l = len(payload)
opcode &= 0b1111 header_len = 2 + (0 if l < 126 else 2 if 126 <= l <= 65535 else 8) + (0 if mask is None else 4)
b1 = opcode | (0b10000000 if fin else 0) | (rsv & 0b01110000) frame = bytearray(header_len + l)
b2 = 0 if mask is None else 0b10000000 if l > 0:
frame[-l:] = payload
frame[0] = (opcode & 0b1111) | (0b10000000 if fin else 0) | (rsv & 0b01110000)
if l < 126: if l < 126:
header = bytes(bytearray((b1, b2 | l))) frame[1] = l
elif 126 <= l <= 65535: elif 126 <= l <= 65535:
header = bytes(bytearray((b1, b2 | 126))) + struct.pack(b'!H', l) frame[2:4] = pack(b'!H', l)
frame[1] = 126
else: else:
header = bytes(bytearray((b1, b2 | 127))) + struct.pack(b'!Q', l) frame[2:10] = pack(b'!Q', l)
frame[1] = 127
if mask is not None: if mask is not None:
header += mask frame[1] |= 0b10000000
payload = fast_mask(payload, mask) frame[header_len-4:header_len] = mask
if l > 0:
fast_mask(memoryview(frame)[-l:], mask)
return header + payload return memoryview(frame)
class MessageWriter(object): class MessageWriter(object):
@ -379,20 +389,20 @@ class WebSocketConnection(HTTPConnection):
self.stop_reading = True self.stop_reading = True
if data: if data:
try: try:
close_code = struct.unpack_from(b'!H', data)[0] close_code = unpack_from(b'!H', data)[0]
except struct.error: except struct_error:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes' data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes'
else: else:
try: try:
utf8_decode(data[2:]) utf8_decode(data[2:])
except ValueError: except ValueError:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8' data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8'
else: else:
if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000): if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000):
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close code reserved' data = pack(b'!H', PROTOCOL_ERROR) + b'close code reserved'
else: else:
close_code = NORMAL_CLOSE close_code = NORMAL_CLOSE
data = struct.pack(b'!H', close_code) data = pack(b'!H', close_code)
f = ReadOnlyFileBuffer(create_frame(1, rcode, data)) f = ReadOnlyFileBuffer(create_frame(1, rcode, data))
f.is_close_frame = opcode == CLOSE f.is_close_frame = opcode == CLOSE
with self.cf_lock: with self.cf_lock:
@ -412,7 +422,7 @@ class WebSocketConnection(HTTPConnection):
if code is None and not reason: if code is None and not reason:
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, b'')) f = ReadOnlyFileBuffer(create_frame(1, CLOSE, b''))
else: else:
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, struct.pack(b'!H', code) + reason)) f = ReadOnlyFileBuffer(create_frame(1, CLOSE, pack(b'!H', code) + reason))
f.is_close_frame = True f.is_close_frame = True
with self.cf_lock: with self.cf_lock:
self.control_frames.append(f) self.control_frames.append(f)

View File

@ -231,19 +231,17 @@ speedup_create_texture(PyObject *self, PyObject *args, PyObject *kw) {
static PyObject* static PyObject*
speedup_websocket_mask(PyObject *self, PyObject *args) { speedup_websocket_mask(PyObject *self, PyObject *args) {
PyObject *data = NULL, *mask = NULL, *ans = NULL; PyObject *data = NULL, *mask = NULL;
Py_ssize_t offset_ = 0; Py_buffer data_buf = {0}, mask_buf = {0};
size_t offset = 0, i = 0; Py_ssize_t offset = 0, i = 0;
char *data_buf = NULL, *mask_buf = NULL, *ans_buf = NULL; char *dbuf = NULL, *mbuf = NULL;
if(!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset_)) return NULL; if(!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset)) return NULL;
offset = (size_t)offset_; if (PyObject_GetBuffer(data, &data_buf, PyBUF_SIMPLE|PyBUF_WRITABLE) != 0) return NULL;
ans = PyBytes_FromStringAndSize(NULL, PyBytes_GET_SIZE(data)); if (PyObject_GetBuffer(mask, &mask_buf, PyBUF_SIMPLE) != 0) return NULL;
if (ans != NULL) { dbuf = (char*)data_buf.buf; mbuf = (char*)mask_buf.buf;
data_buf = PyBytes_AS_STRING(data); mask_buf = PyBytes_AS_STRING(mask); ans_buf = PyBytes_AS_STRING(ans); for(i = 0; i < data_buf.len; i++)
for(i = 0; i < (size_t)PyBytes_GET_SIZE(ans); i++) dbuf[i] ^= mbuf[(i + offset) & 3];
ans_buf[i] = data_buf[i] ^ mask_buf[(i + offset) & 3]; Py_RETURN_NONE;
}
return ans;
} }
#define UTF8_ACCEPT 0 #define UTF8_ACCEPT 0