Use a stricter UTF-8 decoder

The calibre server now passes all Autobahn WebSocket tests in strict
mode.
This commit is contained in:
Kovid Goyal 2015-10-26 21:14:55 +05:30
parent 747f1012e0
commit 96e957a491
3 changed files with 109 additions and 8 deletions

View File

@ -247,6 +247,9 @@ class WebSocketTest(BaseTest):
frags.append({'opcode':(TEXT if i == 0 else CONTINUATION), 'fin':1 if i == len(q)-1 else 0, 'payload':b})
simple_test(frags, close_code=INCONSISTENT_DATA, send_close=False)
for q in (b'\xce', b'\xce\xba\xe1'):
simple_test([{'opcode':TEXT, 'payload':q}], close_code=INCONSISTENT_DATA, send_close=False)
simple_test([
{'opcode':TEXT, 'payload':fragments[0], 'fin':0}, {'opcode':CONTINUATION, 'payload':fragments[1]}
], [''.join(fragments)])
@ -269,7 +272,7 @@ class WebSocketTest(BaseTest):
simple_test([
{'opcode':TEXT, 'fin':0}, {'opcode':CONTINUATION, 'fin':0, 'payload':'x'}, {'opcode':CONTINUATION},], ['x'])
for q in (b'\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5', "Hello-µ@ßöäüàá-UTF-8!!".encode('utf-8')):
for q in (b'\xc2\xb5', b'\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5', "Hello-µ@ßöäüàá-UTF-8!!".encode('utf-8')):
frags = []
for i, b in enumerate(q):
frags.append({'opcode':(TEXT if i == 0 else CONTINUATION), 'fin':1 if i == len(q)-1 else 0, 'payload':b})

View File

@ -5,7 +5,7 @@
from __future__ import (unicode_literals, division, absolute_import,
print_function)
import codecs, httplib, struct, os, weakref, socket
import httplib, struct, os, weakref, socket
from base64 import standard_b64encode
from collections import deque
from functools import partial
@ -22,7 +22,7 @@ from calibre.srv.utils import DESIRED_SEND_BUFFER_SIZE
speedup, err = plugins['speedup']
if not speedup:
raise RuntimeError('Failed to load speedup module with error: ' + err)
fast_mask = speedup.websocket_mask
fast_mask, utf8_decode = speedup.websocket_mask, speedup.utf8_decode
del speedup, err
HANDSHAKE_STR = (
@ -224,6 +224,20 @@ class MessageWriter(object):
conn_id = 0
class UTF8Decoder(object): # {{{
def __init__(self):
self.reset()
def __call__(self, data):
ans, self.state, self.codep = utf8_decode(data, self.state, self.codep)
return ans
def reset(self):
self.state = 0
self.codep = 0
# }}}
class WebSocketConnection(HTTPConnection):
# Internal API {{{
@ -238,7 +252,7 @@ class WebSocketConnection(HTTPConnection):
self.cf_lock = Lock()
self.sending = None
self.send_buf = None
self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict')
self.frag_decoder = UTF8Decoder()
self.ws_close_received = self.ws_close_sent = False
conn_id += 1
self.websocket_connection_id = conn_id
@ -337,12 +351,18 @@ class WebSocketConnection(HTTPConnection):
if self.current_recv_opcode == TEXT:
if message_starting:
self.frag_decoder.reset()
empty_data = not data
try:
data = self.frag_decoder.decode(data, final=message_finished)
except UnicodeDecodeError:
data = self.frag_decoder(data)
except ValueError:
self.frag_decoder.reset()
self.log.error('Client sent undecodeable UTF-8')
return self.websocket_close(INCONSISTENT_DATA, 'Not valid UTF-8')
if message_finished:
if (not data and not empty_data) or self.frag_decoder.state:
self.frag_decoder.reset()
self.log.error('Client sent undecodeable UTF-8')
return self.websocket_close(INCONSISTENT_DATA, 'Not valid UTF-8')
if message_finished:
self.current_recv_opcode = None
self.frag_decoder.reset()
@ -365,8 +385,8 @@ class WebSocketConnection(HTTPConnection):
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be atleast two bytes'
else:
try:
data[2:].decode('utf-8')
except UnicodeDecodeError:
utf8_decode(data[2:])
except ValueError:
data = struct.pack(b'!H', PROTOCOL_ERROR) + b'close frame data must be valid UTF-8'
else:
if close_code < 1000 or close_code in RESERVED_CLOSE_CODES or (1011 < close_code < 3000):

View File

@ -13,6 +13,18 @@
#define CLAMP(value, lower, upper) ((value > upper) ? upper : ((value < lower) ? lower : value))
#define STRIDE(width, r, c) ((width * (r)) + (c))
#ifdef _MSC_VER
#ifndef uint32_t
#define unsigned __int32 uint32_t;
#endif
#ifndef uint8_t
#define unsigned char uint8_t;
#endif
#else
#include <stdint.h>
#endif
static PyObject *
speedup_parse_date(PyObject *self, PyObject *args) {
const char *raw, *orig, *tz;
@ -234,6 +246,68 @@ speedup_websocket_mask(PyObject *self, PyObject *args) {
return ans;
}
#if PY_VERSION_HEX >= 0x03030000
#error Not implemented for python >= 3.3
#endif
#define UTF8_ACCEPT 0
#define UTF8_REJECT 1
static const uint8_t utf8d[] = {
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 00..1f
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 20..3f
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 40..5f
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 60..7f
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, // 80..9f
7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, // a0..bf
8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, // c0..df
0xa,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x4,0x3,0x3, // e0..ef
0xb,0x6,0x6,0x6,0x5,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8, // f0..ff
0x0,0x1,0x2,0x3,0x5,0x8,0x7,0x1,0x1,0x1,0x4,0x6,0x1,0x1,0x1,0x1, // s0..s0
1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1, // s1..s2
1,2,1,1,1,1,1,2,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1, // s3..s4
1,2,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,3,1,3,1,1,1,1,1,1, // s5..s6
1,3,1,1,1,1,1,3,1,3,1,1,1,1,1,1,1,3,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // s7..s8
};
static void inline
utf8_decode_(uint32_t* state, uint32_t* codep, uint8_t byte) {
/* Comes from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ */
uint32_t type = utf8d[byte];
*codep = (*state != UTF8_ACCEPT) ?
(byte & 0x3fu) | (*codep << 6) :
(0xff >> type) & (byte);
*state = utf8d[256 + *state*16 + type];
}
static PyObject*
utf8_decode(PyObject *self, PyObject *args) {
uint32_t state = UTF8_ACCEPT, codep = 0;
PyObject *data = NULL, *ans = NULL;
Py_ssize_t i = 0, pos = 0;
uint32_t *buf = NULL;
unsigned char *dbuf = NULL;
if(!PyArg_ParseTuple(args, "O|II", &data, &state, &codep)) return NULL;
buf = (uint32_t*)PyMem_Malloc(sizeof(uint32_t) * PyBytes_GET_SIZE(data));
if (buf == NULL) return PyErr_NoMemory();
dbuf = (unsigned char*)PyBytes_AS_STRING(data);
for (i = 0; i < PyBytes_GET_SIZE(data); i++) {
utf8_decode_(&state, &codep, dbuf[i]);
if (state == UTF8_ACCEPT) buf[pos++] = codep;
else if (state == UTF8_REJECT) { PyErr_SetString(PyExc_ValueError, "Invalid byte in UTF-8 string"); goto error; }
}
ans = PyUnicode_DecodeUTF32((const char*)buf, pos * sizeof(uint32_t), "strict", NULL);
error:
PyMem_Free(buf); buf = NULL;
if (ans == NULL) return ans;
return Py_BuildValue("NII", ans, state, codep);
}
static PyMethodDef speedup_methods[] = {
{"parse_date", speedup_parse_date, METH_VARARGS,
"parse_date()\n\nParse ISO dates faster."
@ -265,6 +339,10 @@ static PyMethodDef speedup_methods[] = {
"websocket_mask(data, mask [, offset=0)\n\nXOR the data (bytestring) with the specified (must be 4-byte bytestring) mask"
},
{"utf8_decode", utf8_decode, METH_VARARGS,
"utf8_decode(data, [, state=0, codep=0)\n\nDecode an UTF-8 bytestring, using a strict UTF-8 decoder, that unlike python does not allow orphaned surrogates. Returns a unicode object and the state."
},
{NULL, NULL, 0, NULL}
};