From 48f548236eca29709c4dcc14c18e813ab14a0157 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 18 May 2015 10:01:36 +0530 Subject: [PATCH] Finish up HTTP header parsing and add some tests --- src/calibre/srv/http.py | 195 ++++++++++++++++++++++-------- src/calibre/srv/tests/__init__.py | 1 + src/calibre/srv/tests/base.py | 22 ++++ src/calibre/srv/tests/http.py | 47 +++++++ src/calibre/srv/tests/main.py | 101 ++++++++++++++++ src/calibre/srv/utils.py | 69 +++++++++++ 6 files changed, 383 insertions(+), 52 deletions(-) create mode 100644 src/calibre/srv/tests/__init__.py create mode 100644 src/calibre/srv/tests/base.py create mode 100644 src/calibre/srv/tests/http.py create mode 100644 src/calibre/srv/tests/main.py create mode 100644 src/calibre/srv/utils.py diff --git a/src/calibre/srv/http.py b/src/calibre/srv/http.py index 703c82fd73..c970d21121 100644 --- a/src/calibre/srv/http.py +++ b/src/calibre/srv/http.py @@ -6,13 +6,15 @@ from __future__ import (unicode_literals, division, absolute_import, __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' -import httplib, socket, re +import httplib, socket, re, os +from io import BytesIO +import repr as reprlib from urllib import unquote -from urlparse import parse_qs from functools import partial from calibre import as_unicode from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest +from calibre.srv.utils import MultiDict HTTP1 = 'HTTP/1.0' HTTP11 = 'HTTP/1.1' @@ -62,19 +64,19 @@ def parse_request_uri(uri): # {{{ # }}} comma_separated_headers = { - b'Accept', b'Accept-Charset', b'Accept-Encoding', - b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control', - b'Connection', b'Content-Encoding', b'Content-Language', b'Expect', - b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE', - b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning', - b'WWW-Authenticate' + 'Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate' } decoded_headers = { 'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', -} +} | comma_separated_headers -def read_headers(readline, max_line_size, hdict=None): # {{{ +def read_headers(readline): # {{{ """ Read headers from the given stream into the given header dict. @@ -87,8 +89,27 @@ def read_headers(readline, max_line_size, hdict=None): # {{{ This function raises ValueError when the read bytes violate the HTTP spec. You should probably return "400 Bad Request" if this happens. """ - if hdict is None: - hdict = {} + hdict = MultiDict() + + def safe_decode(hname, value): + try: + return value.decode('ascii') + except UnicodeDecodeError: + if hname in decoded_headers: + raise + return value + + current_key = current_value = None + + def commit(): + if current_key: + key = current_key.decode('ascii') + val = safe_decode(key, current_value) + if key in comma_separated_headers: + existing = hdict.pop(key) + if existing is not None: + val = existing + ', ' + val + hdict[key] = val while True: line = readline() @@ -98,32 +119,22 @@ def read_headers(readline, max_line_size, hdict=None): # {{{ if line == b'\r\n': # Normal end of headers + commit() break if not line.endswith(b'\r\n'): raise ValueError("HTTP requires CRLF terminators") - if line[0] in (b' ', b'\t'): + if line[0] in b' \t': # It's a continuation line. - v = line.strip() + if current_key is None or current_value is None: + raise ValueError('Orphaned continuation line') + current_value += b' ' + line.strip() else: - try: - k, v = line.split(b':', 1) - except ValueError: - raise ValueError("Illegal header line.") - k = k.strip().title() - v = v.strip() - hname = k.decode('ascii') - - if k in comma_separated_headers: - existing = hdict.get(hname) - if existing: - v = b", ".join((existing, v)) - try: - v = v.decode('ascii') - except UnicodeDecodeError: - if hname in decoded_headers: - raise - hdict[hname] = v + commit() + current_key = current_value = None + k, v = line.split(b':', 1) + current_key = k.strip().title() + current_value = v.strip() return hdict # }}} @@ -133,53 +144,130 @@ def http_communicate(conn): request_seen = False try: while True: - # (re)set req to None so that if something goes wrong in + # (re)set pair to None so that if something goes wrong in # the HTTPPair constructor, the error doesn't # get written to the previous request. - req = None - req = conn.server_loop.http_handler(conn) + pair = None + pair = conn.server_loop.http_handler(conn) # This order of operations should guarantee correct pipelining. - req.parse_request() - if not req.ready: + pair.parse_request() + if not pair.ready: # Something went wrong in the parsing (and the server has # probably already made a simple_response). Return and # let the conn close. return request_seen = True - req.respond() - if req.close_connection: + pair.respond() + if pair.close_connection: return except socket.timeout: # Don't error if we're between requests; only error # if 1) no request has been started at all, or 2) we're # in the middle of a request. This allows persistent # connections for HTTP/1.1 - if (not request_seen) or (req and req.started_request): + if (not request_seen) or (pair and pair.started_request): # Don't bother writing the 408 if the response # has already started being written. - if req and not req.sent_headers: - req.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout") + if pair and not pair.sent_headers: + pair.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout") except NonHTTPConnRequest: raise except Exception: - conn.server_loop.log.exception() - if req and not req.sent_headers: - req.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error") + conn.server_loop.log.exception('Error serving request:', pair.path if pair else None) + if pair and not pair.sent_headers: + pair.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error") + +class FixedSizeReader(object): + + def __init__(self, socket_file, content_length): + self.socket_file, self.remaining = socket_file, content_length + + def __call__(self, size=-1): + if size < 0: + size = self.remaining + size = min(self.remaining, size) + if size < 1: + return b'' + data = self.socket_file.read(size) + self.remaining -= len(data) + return data + + +class ChunkedReader(object): + + def __init__(self, socket_file, maxsize): + self.socket_file, self.maxsize = socket_file, maxsize + self.rbuf = BytesIO() + self.bytes_read = 0 + self.finished = False + + def check_size(self): + if self.bytes_read > self.maxsize: + raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize) + + def read_chunk(self): + if self.finished: + return + line = self.socket_file.readline() + self.bytes_read += len(line) + self.check_size() + chunk_size = line.strip().split(b';', 1)[0] + try: + chunk_size = int(line, 16) + 2 + except Exception: + raise ValueError('%s is not a valid chunk size' % reprlib.repr(chunk_size)) + if chunk_size + self.bytes_read > self.maxsize: + raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize) + chunk = self.socket_file.read(chunk_size) + if len(chunk) < chunk_size: + raise ValueError('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size)) + if not chunk.endswith(b'\r\n'): + raise ValueError('Bad chunked encoding: %r != CRLF' % chunk[:-2]) + self.rbuf.seek(0, os.SEEK_END) + self.bytes_read += chunk_size + if chunk_size == 2: + self.finished = True + else: + self.rbuf.write(chunk[:-2]) + + def __call__(self, size=-1): + if size < 0: + # Read all data + while not self.finished: + self.read_chunk() + self.rbuf.seek(0) + rv = self.rbuf.read() + if rv: + self.rbuf.truncate(0) + return rv + if size == 0: + return b'' + while self.rbuf.tell() < size and not self.finished: + self.read_chunk() + data = self.rbuf.getvalue() + self.rbuf.truncate(0) + if size < len(data): + self.rbuf.write(data[size:]) + return data[:size] + return data class HTTPPair(object): ''' Represents a HTTP request/response pair ''' - def __init__(self, conn): + def __init__(self, conn, handle_request): self.conn = conn self.server_loop = conn.server_loop self.max_header_line_size = self.server_loop.max_header_line_size self.scheme = 'http' if self.server_loop.ssl_context is None else 'https' - self.inheaders = {} - self.outheaders = [] + self.inheaders = MultiDict() + self.outheaders = MultiDict() + self.handle_request = handle_request + self.path = () + self.qs = MultiDict() """When True, the request has been parsed and is ready to begin generating the response. When False, signals the calling Connection that the response @@ -198,6 +286,9 @@ class HTTPPair(object): self.status = b'' self.sent_headers = False + self.request_content_length = 0 + self.chunked_read = False + def parse_request(self): """Parse the next HTTP request start-line and message-headers.""" try: @@ -273,7 +364,7 @@ class HTTPPair(object): if b'?' in path: path, qs = path.split(b'?', 1) try: - self.qs = {k.decode('utf-8'):tuple(x.decode('utf-8') for x in v) for k, v in parse_qs(qs, keep_blank_values=True).iteritems()} + self.qs = MultiDict.create_from_query_string(qs) except Exception: self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line", 'Unparseable query string') @@ -293,13 +384,13 @@ class HTTPPair(object): def read_request_headers(self): # then all the http headers try: - read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size), self.inheaders) - content_length = int(self.inheaders.get('Content-Length', 0)) + self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size)) + self.request_content_length = int(self.inheaders.get('Content-Length', 0)) except ValueError as e: self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e)) return False - if content_length > self.server_loop.max_request_body_size: + if self.request_content_length > self.server_loop.max_request_body_size: self.simple_response( httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large", "The entity sent with the request exceeds the maximum " diff --git a/src/calibre/srv/tests/__init__.py b/src/calibre/srv/tests/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/src/calibre/srv/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/calibre/srv/tests/base.py b/src/calibre/srv/tests/base.py new file mode 100644 index 0000000000..643e9df1ec --- /dev/null +++ b/src/calibre/srv/tests/base.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python2 +# vim:fileencoding=UTF-8:ts=4:sw=4:sta:et:sts=4:ai +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +__license__ = 'GPL v3' +__copyright__ = '2011, Kovid Goyal ' +__docformat__ = 'restructuredtext en' + +import unittest, shutil +from functools import partial + +rmtree = partial(shutil.rmtree, ignore_errors=True) + +class BaseTest(unittest.TestCase): + + longMessage = True + maxDiff = None + + ae = unittest.TestCase.assertEqual + + diff --git a/src/calibre/srv/tests/http.py b/src/calibre/srv/tests/http.py new file mode 100644 index 0000000000..59243290a0 --- /dev/null +++ b/src/calibre/srv/tests/http.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +__license__ = 'GPL v3' +__copyright__ = '2015, Kovid Goyal ' + +import textwrap +from io import BytesIO + +from calibre.srv.tests.base import BaseTest + +def headers(raw): + return BytesIO(textwrap.dedent(raw).encode('utf-8')) + +class TestHTTP(BaseTest): + + def test_header_parsing(self): + 'Test parsing of HTTP headers' + from calibre.srv.http import read_headers + + def test(name, raw, **kwargs): + hdict = read_headers(headers(raw).readline) + self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed') + + test('Continuation line parsing', + '''\ + a: one\r + b: two\r + 2\r + \t3\r + c:three\r + \r\n''', a='one', b='two 2 3', c='three') + + test('Non-ascii headers parsing', + '''\ + a:mūs\r + \r\n''', a='mūs'.encode('utf-8')) + + with self.assertRaises(ValueError): + read_headers(headers('Connection:mūs\r\n').readline) + read_headers(headers('Connection\r\n').readline) + read_headers(headers('Connection:a\r\n').readline) + read_headers(headers('Connection:a\n').readline) + read_headers(headers(' Connection:a\n').readline) + diff --git a/src/calibre/srv/tests/main.py b/src/calibre/srv/tests/main.py new file mode 100644 index 0000000000..ea41078d83 --- /dev/null +++ b/src/calibre/srv/tests/main.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +__license__ = 'GPL v3' +__copyright__ = '2015, Kovid Goyal ' + +import unittest, os, argparse, time, functools + +try: + import init_calibre + del init_calibre +except ImportError: + pass + +def no_endl(f): + @functools.wraps(f) + def func(*args, **kwargs): + self = f.__self__ + orig = self.stream.writeln + self.stream.writeln = self.stream.write + try: + return f(*args, **kwargs) + finally: + self.stream.writeln = orig + return func + +class TestResult(unittest.TextTestResult): + + def __init__(self, *args, **kwargs): + super(TestResult, self).__init__(*args, **kwargs) + self.start_time = {} + for x in ('Success', 'Error', 'Failure', 'Skip', 'ExpectedFailure', 'UnexpectedSuccess'): + x = 'add' + x + setattr(self, x, no_endl(getattr(self, x))) + self.times = {} + + def startTest(self, test): + self.start_time[test] = time.time() + return super(TestResult, self).startTest(test) + + def stopTest(self, test): + orig = self.stream.writeln + self.stream.writeln = self.stream.write + super(TestResult, self).stopTest(test) + elapsed = time.time() + elapsed -= self.start_time.get(test, elapsed) + self.times[test] = elapsed + self.stream.writeln = orig + self.stream.writeln(' [%.1g s]' % elapsed) + + def stopTestRun(self): + super(TestResult, self).stopTestRun() + if self.wasSuccessful(): + tests = sorted(self.times, key=self.times.get, reverse=True) + slowest = ['%s [%g s]' % (t.id(), self.times[t]) for t in tests[:3]] + if len(slowest) > 1: + self.stream.writeln('\nSlowest tests: %s' % ' '.join(slowest)) + +def find_tests(): + return unittest.defaultTestLoader.discover(os.path.dirname(os.path.abspath(__file__)), pattern='*.py') + +def run_tests(find_tests=find_tests): + parser = argparse.ArgumentParser() + parser.add_argument('name', nargs='?', default=None, + help='The name of the test to run, for e.g. writing.WritingTest.many_many_basic or .many_many_basic for a shortcut') + args = parser.parse_args() + if args.name and args.name.startswith('.'): + tests = find_tests() + q = args.name[1:] + if not q.startswith('test_'): + q = 'test_' + q + ans = None + try: + for suite in tests: + for test in suite._tests: + if test.__class__.__name__ == 'ModuleImportFailure': + raise Exception('Failed to import a test module: %s' % test) + for s in test: + if s._testMethodName == q: + ans = s + raise StopIteration() + except StopIteration: + pass + if ans is None: + print ('No test named %s found' % args.name) + raise SystemExit(1) + tests = ans + else: + tests = unittest.defaultTestLoader.loadTestsFromName(args.name) if args.name else find_tests() + r = unittest.TextTestRunner + r.resultclass = TestResult + r(verbosity=4).run(tests) + +if __name__ == '__main__': + from calibre.utils.config_base import reset_tweaks_to_default + from calibre.ebooks.metadata.book.base import reset_field_metadata + reset_tweaks_to_default() + reset_field_metadata() + run_tests() diff --git a/src/calibre/srv/utils.py b/src/calibre/srv/utils.py new file mode 100644 index 0000000000..ddb0649ebd --- /dev/null +++ b/src/calibre/srv/utils.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +__license__ = 'GPL v3' +__copyright__ = '2015, Kovid Goyal ' + +from urlparse import parse_qs + +class MultiDict(dict): + + def __setitem__(self, key, val): + vals = dict.get(self, key, []) + vals.append(val) + dict.__setitem__(self, key, vals) + + def __getitem__(self, key): + return dict.__getitem__(self, key)[-1] + + @staticmethod + def create_from_query_string(qs): + ans = MultiDict() + for k, v in parse_qs(qs, keep_blank_values=True): + dict.__setitem__(ans, k.decode('utf-8'), [x.decode('utf-8') for x in v]) + return ans + + def update_from_listdict(self, ld): + for key, values in ld.iteritems(): + for val in values: + self[key] = val + + def items(self, duplicates=True): + for k, v in dict.iteritems(self): + if duplicates: + for x in v: + yield k, x + else: + yield k, v[-1] + iteritems = items + + def values(self, duplicates=True): + for v in dict.itervalues(self): + if duplicates: + for x in v: + yield x + else: + yield v[-1] + itervalues = values + + def set(self, key, val, replace=False): + if replace: + dict.__setitem__(self, key, [val]) + else: + self[key] = val + + def get(self, key, default=None, all=False): + if all: + try: + return dict.__getitem__(self, key) + except KeyError: + return [] + return self.__getitem__(key) + + def pop(self, key, default=None, all=False): + ans = dict.pop(self, key, default) + if ans is default: + return [] if all else default + return ans if all else ans[-1]