More WS speedups by using the buffer protocol to avoid mallocs

This commit is contained in:
Kovid Goyal 2015-10-28 11:34:33 +05:30
parent 689bec8c46
commit 6925b23754
4 changed files with 150 additions and 119 deletions

View File

@ -213,6 +213,29 @@ class Connection(object): # {{{
return b''
raise
def recv_into(self, buf, amt=0):
amt = amt or len(buf)
if self.read_buffer.has_data:
data = self.read_buffer.read(amt)
buf[0:len(data)] = data
return len(data)
try:
bytes_read = self.socket.recv_into(buf, amt)
self.last_activity = monotonic()
if bytes_read == 0:
# a closed connection is indicated by signaling
# a read condition, and having recv() return 0.
self.ready = False
return 0
return bytes_read
except socket.error as e:
if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
return 0
if e.errno in socket_errors_socket_closed:
self.ready = False
return 0
raise
def fill_read_buffer(self):
try:
num = self.read_buffer.recv_from(self.socket)

View File

@ -37,7 +37,7 @@ class WSClient(object):
self.socket.sendall(HANDSHAKE_STR.format(self.key).encode('ascii'))
self.read_buf = deque()
self.read_upgrade_response()
self.mask = os.urandom(4)
self.mask = memoryview(os.urandom(4))
self.frames = []
def read_upgrade_response(self):

View File

@ -5,12 +5,12 @@
from __future__ import (unicode_literals, division, absolute_import,
print_function)
import httplib, struct, os, weakref, socket
import httplib, os, weakref, socket
from base64 import standard_b64encode
from collections import deque
from functools import partial
from hashlib import sha1
from Queue import Queue, Empty
from struct import unpack_from, pack, error as struct_error
from threading import Lock
from calibre import as_unicode
@ -58,28 +58,31 @@ RESERVED_CLOSE_CODES = (1004,1005,1006,)
class ReadFrame(object): # {{{
def __init__(self):
self.header_buf = bytearray(14)
self.rbuf = bytearray(CHUNK_SIZE)
self.reset()
def reset(self):
self.state = self.read_header0
self.control_buf = []
self.header_view = memoryview(self.header_buf)[:6]
self.state = self.read_header
def __call__(self, conn):
return self.state(conn)
def read_header0(self, conn):
data = conn.recv(1)
if not data:
def read_header(self, conn):
num_bytes = conn.recv_into(self.header_view)
if num_bytes == 0:
return
b = ord(data)
self.fin = bool(b & 0b10000000)
if b & 0b01110000:
read_bytes = 6 - len(self.header_view) + num_bytes
if read_bytes > 2:
b1, b2 = self.header_buf[0], self.header_buf[1]
self.fin = bool(b1 & 0b10000000)
if b1 & 0b01110000:
conn.log.error('RSV bits set in frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'RSV bits set')
return
self.opcode = b & 0b1111
self.state = self.read_header1
self.opcode = b1 & 0b1111
self.is_control = self.opcode in CONTROL_CODES
if self.opcode not in ALL_CODES:
conn.log.error('Unknown OPCODE from client: %r' % self.opcode)
@ -90,82 +93,83 @@ class ReadFrame(object): # {{{
conn.websocket_close(PROTOCOL_ERROR, 'Fragmented control frame')
return
def read_header1(self, conn):
data = conn.recv(1)
if not data:
return
b = ord(data)
self.mask = b & 0b10000000
if not self.mask:
mask = b2 & 0b10000000
if not mask:
conn.log.error('Unmasked packet from client')
conn.websocket_close(PROTOCOL_ERROR, 'Unmasked packet not allowed')
self.reset()
return
self.payload_length = b & 0b01111111
if self.is_control and self.payload_length > 125:
self.payload_length = l = b2 & 0b01111111
if self.is_control and l > 125:
conn.log.error('Too large control frame from client')
conn.websocket_close(PROTOCOL_ERROR, 'Control frame too large')
self.reset()
return
self.mask_buf = b''
if self.payload_length == 126:
self.plbuf = b''
self.state = partial(self.read_payload_length, 2)
elif self.payload_length == 127:
self.plbuf = b''
self.state = partial(self.read_payload_length, 8)
header_len = 6 + (0 if l < 126 else 2 if l == 126 else 8)
if header_len <= read_bytes:
self.process_header(conn)
else:
self.state = self.read_masking_key
self.header_view = memoryview(self.header_buf)[read_bytes:header_len]
self.state = self.finish_reading_header
else:
self.header_view = self.header_view[num_bytes:]
def read_payload_length(self, size_in_bytes, conn):
num_left = size_in_bytes - len(self.plbuf)
data = conn.recv(num_left)
if not data:
def finish_reading_header(self, conn):
num_bytes = conn.recv_into(self.header_view)
if num_bytes == 0:
return
self.plbuf += data
if len(self.plbuf) < size_in_bytes:
return
fmt = b'!H' if size_in_bytes == 2 else b'!Q'
self.payload_length = struct.unpack(fmt, self.plbuf)[0]
del self.plbuf
self.state = self.read_masking_key
if num_bytes >= len(self.header_view):
self.process_header(conn)
else:
self.header_view = self.header_view[num_bytes:]
def read_masking_key(self, conn):
num_left = 4 - len(self.mask_buf)
data = conn.recv(num_left)
if not data:
return
self.mask_buf += data
if len(self.mask_buf) < 4:
return
self.state = self.read_payload
self.pos = 0
def process_header(self, conn):
if self.payload_length < 126:
self.mask = memoryview(self.header_buf)[2:6]
elif self.payload_length == 126:
self.payload_length, = unpack_from(b'!H', self.header_buf, 2)
self.mask = memoryview(self.header_buf)[4:8]
else:
self.payload_length, = unpack_from(b'!Q', self.header_buf, 2)
self.mask = memoryview(self.header_buf)[10:14]
self.frame_starting = True
self.bytes_received = 0
if self.payload_length <= CHUNK_SIZE:
if self.payload_length == 0:
conn.ws_data_received(b'', self.opcode, True, True, self.fin)
self.reset()
else:
self.rview = memoryview(self.rbuf)[:self.payload_length]
self.state = self.read_packet
else:
self.rview = memoryview(self.rbuf)
self.state = self.read_payload
def read_packet(self, conn):
num_bytes = conn.recv_into(self.rview)
if num_bytes == 0:
return
if num_bytes >= len(self.rview):
data = memoryview(self.rbuf)[:self.payload_length]
fast_mask(data, self.mask)
conn.ws_data_received(data.tobytes(), self.opcode, True, True, self.fin)
self.reset()
else:
self.rview = self.rview[num_bytes:]
def read_payload(self, conn):
bytes_left = self.payload_length - self.pos
if bytes_left > 0:
data = conn.recv(min(bytes_left, CHUNK_SIZE))
if not data:
num_bytes = conn.recv_into(self.rview, min(len(self.rview), self.payload_length - self.bytes_received))
if num_bytes == 0:
return
data = fast_mask(data, self.mask_buf, self.pos)
else:
data = b''
self.pos += len(data)
frame_finished = self.pos >= self.payload_length
if self.is_control:
self.control_buf.append(data)
if frame_finished:
data = b''.join(self.control_buf)
else:
return
conn.ws_data_received(data, self.opcode, self.frame_starting, frame_finished, self.fin)
data = memoryview(self.rbuf)[:num_bytes]
fast_mask(data, self.mask, self.bytes_received)
self.bytes_received += num_bytes
frame_finished = self.bytes_received >= self.payload_length
conn.ws_data_received(data.tobytes(), self.opcode, self.frame_starting, frame_finished, self.fin)
self.frame_starting = False
if frame_finished:
self.reset()
# }}}
# Sending frames {{{
@ -174,20 +178,26 @@ def create_frame(fin, opcode, payload, mask=None, rsv=0):
if isinstance(payload, type('')):
payload = payload.encode('utf-8')
l = len(payload)
opcode &= 0b1111
b1 = opcode | (0b10000000 if fin else 0) | (rsv & 0b01110000)
b2 = 0 if mask is None else 0b10000000
header_len = 2 + (0 if l < 126 else 2 if 126 <= l <= 65535 else 8) + (0 if mask is None else 4)
frame = bytearray(header_len + l)
if l > 0:
frame[-l:] = payload
frame[0] = (opcode & 0b1111) | (0b10000000 if fin else 0) | (rsv & 0b01110000)
if l < 126:
header = bytes(bytearray((b1, b2 | l)))
frame[1] = l
elif 126 <= l <= 65535:
header = bytes(bytearray((b1, b2 | 126))) + struct.pack(b'!H', l)
frame[2:4] = pack(b'!H', l)
frame[1] = 126
else:
header = bytes(bytearray((b1, b2 | 127))) + struct.pack(b'!Q', l)
frame[2:10] = pack(b'!Q', l)
frame[1] = 127
if mask is not None:
header += mask
payload = fast_mask(payload, mask)
frame[1] |= 0b10000000
frame[header_len-4:header_len] = mask
if l > 0:
fast_mask(memoryview(frame)[-l:], mask)
return header + payload
return memoryview(frame)
class MessageWriter(object):
@ -379,20 +389,20 @@ class WebSocketConnection(HTTPConnection):
self.stop_reading = True
if data:
try:
close_code = struct.unpack_from(b'!H', data)[0]
except struct.error:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes'
close_code = unpack_from(b'!H', data)[0]
except struct_error:
data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes'
else:
try:
utf8_decode(data[2:])
except ValueError:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8'
data = pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8'
else:
if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000):
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close code reserved'
data = pack(b'!H', PROTOCOL_ERROR) + b'close code reserved'
else:
close_code = NORMAL_CLOSE
data = struct.pack(b'!H', close_code)
data = pack(b'!H', close_code)
f = ReadOnlyFileBuffer(create_frame(1, rcode, data))
f.is_close_frame = opcode == CLOSE
with self.cf_lock:
@ -412,7 +422,7 @@ class WebSocketConnection(HTTPConnection):
if code is None and not reason:
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, b''))
else:
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, struct.pack(b'!H', code) + reason))
f = ReadOnlyFileBuffer(create_frame(1, CLOSE, pack(b'!H', code) + reason))
f.is_close_frame = True
with self.cf_lock:
self.control_frames.append(f)

View File

@ -231,19 +231,17 @@ speedup_create_texture(PyObject *self, PyObject *args, PyObject *kw) {
static PyObject*
speedup_websocket_mask(PyObject *self, PyObject *args) {
PyObject *data = NULL, *mask = NULL, *ans = NULL;
Py_ssize_t offset_ = 0;
size_t offset = 0, i = 0;
char *data_buf = NULL, *mask_buf = NULL, *ans_buf = NULL;
if(!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset_)) return NULL;
offset = (size_t)offset_;
ans = PyBytes_FromStringAndSize(NULL, PyBytes_GET_SIZE(data));
if (ans != NULL) {
data_buf = PyBytes_AS_STRING(data); mask_buf = PyBytes_AS_STRING(mask); ans_buf = PyBytes_AS_STRING(ans);
for(i = 0; i < (size_t)PyBytes_GET_SIZE(ans); i++)
ans_buf[i] = data_buf[i] ^ mask_buf[(i + offset) & 3];
}
return ans;
PyObject *data = NULL, *mask = NULL;
Py_buffer data_buf = {0}, mask_buf = {0};
Py_ssize_t offset = 0, i = 0;
char *dbuf = NULL, *mbuf = NULL;
if(!PyArg_ParseTuple(args, "OO|n", &data, &mask, &offset)) return NULL;
if (PyObject_GetBuffer(data, &data_buf, PyBUF_SIMPLE|PyBUF_WRITABLE) != 0) return NULL;
if (PyObject_GetBuffer(mask, &mask_buf, PyBUF_SIMPLE) != 0) return NULL;
dbuf = (char*)data_buf.buf; mbuf = (char*)mask_buf.buf;
for(i = 0; i < data_buf.len; i++)
dbuf[i] ^= mbuf[(i + offset) & 3];
Py_RETURN_NONE;
}
#define UTF8_ACCEPT 0