Finish up HTTP header parsing and add some tests

This commit is contained in:
Kovid Goyal 2015-05-18 10:01:36 +05:30
parent 18a04ed23d
commit 48f548236e
6 changed files with 383 additions and 52 deletions

View File

@ -6,13 +6,15 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3' __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import httplib, socket, re import httplib, socket, re, os
from io import BytesIO
import repr as reprlib
from urllib import unquote from urllib import unquote
from urlparse import parse_qs
from functools import partial from functools import partial
from calibre import as_unicode from calibre import as_unicode
from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest
from calibre.srv.utils import MultiDict
HTTP1 = 'HTTP/1.0' HTTP1 = 'HTTP/1.0'
HTTP11 = 'HTTP/1.1' HTTP11 = 'HTTP/1.1'
@ -62,19 +64,19 @@ def parse_request_uri(uri): # {{{
# }}} # }}}
comma_separated_headers = { comma_separated_headers = {
b'Accept', b'Accept-Charset', b'Accept-Encoding', 'Accept', 'Accept-Charset', 'Accept-Encoding',
b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control', 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
b'Connection', b'Content-Encoding', b'Content-Language', b'Expect', 'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE', 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning', 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
b'WWW-Authenticate' 'WWW-Authenticate'
} }
decoded_headers = { decoded_headers = {
'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', 'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
} } | comma_separated_headers
def read_headers(readline, max_line_size, hdict=None): # {{{ def read_headers(readline): # {{{
""" """
Read headers from the given stream into the given header dict. Read headers from the given stream into the given header dict.
@ -87,8 +89,27 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
This function raises ValueError when the read bytes violate the HTTP spec. This function raises ValueError when the read bytes violate the HTTP spec.
You should probably return "400 Bad Request" if this happens. You should probably return "400 Bad Request" if this happens.
""" """
if hdict is None: hdict = MultiDict()
hdict = {}
def safe_decode(hname, value):
try:
return value.decode('ascii')
except UnicodeDecodeError:
if hname in decoded_headers:
raise
return value
current_key = current_value = None
def commit():
if current_key:
key = current_key.decode('ascii')
val = safe_decode(key, current_value)
if key in comma_separated_headers:
existing = hdict.pop(key)
if existing is not None:
val = existing + ', ' + val
hdict[key] = val
while True: while True:
line = readline() line = readline()
@ -98,32 +119,22 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
if line == b'\r\n': if line == b'\r\n':
# Normal end of headers # Normal end of headers
commit()
break break
if not line.endswith(b'\r\n'): if not line.endswith(b'\r\n'):
raise ValueError("HTTP requires CRLF terminators") raise ValueError("HTTP requires CRLF terminators")
if line[0] in (b' ', b'\t'): if line[0] in b' \t':
# It's a continuation line. # It's a continuation line.
v = line.strip() if current_key is None or current_value is None:
raise ValueError('Orphaned continuation line')
current_value += b' ' + line.strip()
else: else:
try: commit()
k, v = line.split(b':', 1) current_key = current_value = None
except ValueError: k, v = line.split(b':', 1)
raise ValueError("Illegal header line.") current_key = k.strip().title()
k = k.strip().title() current_value = v.strip()
v = v.strip()
hname = k.decode('ascii')
if k in comma_separated_headers:
existing = hdict.get(hname)
if existing:
v = b", ".join((existing, v))
try:
v = v.decode('ascii')
except UnicodeDecodeError:
if hname in decoded_headers:
raise
hdict[hname] = v
return hdict return hdict
# }}} # }}}
@ -133,53 +144,130 @@ def http_communicate(conn):
request_seen = False request_seen = False
try: try:
while True: while True:
# (re)set req to None so that if something goes wrong in # (re)set pair to None so that if something goes wrong in
# the HTTPPair constructor, the error doesn't # the HTTPPair constructor, the error doesn't
# get written to the previous request. # get written to the previous request.
req = None pair = None
req = conn.server_loop.http_handler(conn) pair = conn.server_loop.http_handler(conn)
# This order of operations should guarantee correct pipelining. # This order of operations should guarantee correct pipelining.
req.parse_request() pair.parse_request()
if not req.ready: if not pair.ready:
# Something went wrong in the parsing (and the server has # Something went wrong in the parsing (and the server has
# probably already made a simple_response). Return and # probably already made a simple_response). Return and
# let the conn close. # let the conn close.
return return
request_seen = True request_seen = True
req.respond() pair.respond()
if req.close_connection: if pair.close_connection:
return return
except socket.timeout: except socket.timeout:
# Don't error if we're between requests; only error # Don't error if we're between requests; only error
# if 1) no request has been started at all, or 2) we're # if 1) no request has been started at all, or 2) we're
# in the middle of a request. This allows persistent # in the middle of a request. This allows persistent
# connections for HTTP/1.1 # connections for HTTP/1.1
if (not request_seen) or (req and req.started_request): if (not request_seen) or (pair and pair.started_request):
# Don't bother writing the 408 if the response # Don't bother writing the 408 if the response
# has already started being written. # has already started being written.
if req and not req.sent_headers: if pair and not pair.sent_headers:
req.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout") pair.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
except NonHTTPConnRequest: except NonHTTPConnRequest:
raise raise
except Exception: except Exception:
conn.server_loop.log.exception() conn.server_loop.log.exception('Error serving request:', pair.path if pair else None)
if req and not req.sent_headers: if pair and not pair.sent_headers:
req.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error") pair.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
class FixedSizeReader(object):

    ''' A callable that reads a request body of known Content-Length from a
    socket file object, returning b'' once the body is exhausted. '''

    def __init__(self, socket_file, content_length):
        self.socket_file = socket_file
        self.remaining = content_length  # bytes of the body not yet read

    def __call__(self, size=-1):
        # A negative size means "read everything that is left"
        limit = self.remaining if size < 0 else min(size, self.remaining)
        if limit <= 0:
            return b''
        data = self.socket_file.read(limit)
        self.remaining -= len(data)
        return data
class ChunkedReader(object):

    ''' A callable that reads a request body sent with HTTP chunked transfer
    encoding, enforcing a maximum total request size.

    Raises ValueError on malformed chunked data and MaxSizeExceeded when the
    client sends more than maxsize bytes.
    NOTE(review): trailer headers after the final zero-length chunk are not
    consumed — confirm callers never send trailers. '''

    def __init__(self, socket_file, maxsize):
        self.socket_file, self.maxsize = socket_file, maxsize
        self.rbuf = BytesIO()    # de-chunked data buffered but not yet returned
        self.bytes_read = 0      # raw bytes consumed from the socket so far
        self.finished = False    # set once the zero-length terminator chunk is seen

    def check_size(self):
        # Enforce the overall request entity size limit
        if self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize)

    def read_chunk(self):
        ''' Read a single chunk from the socket into self.rbuf, setting
        self.finished when the terminating zero-length chunk is read. '''
        if self.finished:
            return
        line = self.socket_file.readline()
        self.bytes_read += len(line)
        self.check_size()
        # The size line may carry chunk extensions after a semi-colon; only
        # the leading hexadecimal number is the size
        chunk_size = line.strip().split(b';', 1)[0]
        try:
            # Bug fix: parse the extracted size token, not the whole line
            # (the raw line breaks when chunk extensions are present).
            # + 2 accounts for the CRLF that terminates the chunk data.
            chunk_size = int(chunk_size.decode('ascii'), 16) + 2
        except Exception:
            raise ValueError('%s is not a valid chunk size' % reprlib.repr(chunk_size))
        if chunk_size + self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize)
        chunk = self.socket_file.read(chunk_size)
        if len(chunk) < chunk_size:
            raise ValueError('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size))
        if not chunk.endswith(b'\r\n'):
            # Bug fix: show the offending trailing bytes, not the chunk data
            raise ValueError('Bad chunked encoding: %r != CRLF' % chunk[-2:])
        self.rbuf.seek(0, os.SEEK_END)
        self.bytes_read += chunk_size
        if chunk_size == 2:
            # A zero-length chunk (b'0\r\n') marks the end of the body
            self.finished = True
        else:
            self.rbuf.write(chunk[:-2])

    def __call__(self, size=-1):
        if size < 0:
            # Read all remaining data
            while not self.finished:
                self.read_chunk()
            self.rbuf.seek(0)
            rv = self.rbuf.read()
            if rv:
                # Replace the buffer instead of truncate(0): truncate() does
                # not move the stream position, which previously left a stale
                # position that could zero-pad later writes
                self.rbuf = BytesIO()
            return rv
        if size == 0:
            return b''
        # rbuf's position is kept at its end, so tell() is the buffered amount
        while self.rbuf.tell() < size and not self.finished:
            self.read_chunk()
        data = self.rbuf.getvalue()
        # Keep any surplus beyond size; position must remain at the end so
        # that tell() keeps reporting the buffered amount
        self.rbuf = BytesIO(data[size:])
        self.rbuf.seek(0, os.SEEK_END)
        if size < len(data):
            return data[:size]
        return data
class HTTPPair(object): class HTTPPair(object):
''' Represents a HTTP request/response pair ''' ''' Represents a HTTP request/response pair '''
def __init__(self, conn): def __init__(self, conn, handle_request):
self.conn = conn self.conn = conn
self.server_loop = conn.server_loop self.server_loop = conn.server_loop
self.max_header_line_size = self.server_loop.max_header_line_size self.max_header_line_size = self.server_loop.max_header_line_size
self.scheme = 'http' if self.server_loop.ssl_context is None else 'https' self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
self.inheaders = {} self.inheaders = MultiDict()
self.outheaders = [] self.outheaders = MultiDict()
self.handle_request = handle_request
self.path = ()
self.qs = MultiDict()
"""When True, the request has been parsed and is ready to begin generating """When True, the request has been parsed and is ready to begin generating
the response. When False, signals the calling Connection that the response the response. When False, signals the calling Connection that the response
@ -198,6 +286,9 @@ class HTTPPair(object):
self.status = b'' self.status = b''
self.sent_headers = False self.sent_headers = False
self.request_content_length = 0
self.chunked_read = False
def parse_request(self): def parse_request(self):
"""Parse the next HTTP request start-line and message-headers.""" """Parse the next HTTP request start-line and message-headers."""
try: try:
@ -273,7 +364,7 @@ class HTTPPair(object):
if b'?' in path: if b'?' in path:
path, qs = path.split(b'?', 1) path, qs = path.split(b'?', 1)
try: try:
self.qs = {k.decode('utf-8'):tuple(x.decode('utf-8') for x in v) for k, v in parse_qs(qs, keep_blank_values=True).iteritems()} self.qs = MultiDict.create_from_query_string(qs)
except Exception: except Exception:
self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line", self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line",
'Unparseable query string') 'Unparseable query string')
@ -293,13 +384,13 @@ class HTTPPair(object):
def read_request_headers(self): def read_request_headers(self):
# then all the http headers # then all the http headers
try: try:
read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size), self.inheaders) self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size))
content_length = int(self.inheaders.get('Content-Length', 0)) self.request_content_length = int(self.inheaders.get('Content-Length', 0))
except ValueError as e: except ValueError as e:
self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e)) self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e))
return False return False
if content_length > self.server_loop.max_request_body_size: if self.request_content_length > self.server_loop.max_request_body_size:
self.simple_response( self.simple_response(
httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large", httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large",
"The entity sent with the request exceeds the maximum " "The entity sent with the request exceeds the maximum "

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,22 @@
#!/usr/bin/env python2
# vim:fileencoding=UTF-8:ts=4:sw=4:sta:et:sts=4:ai
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2011, Kovid Goyal <kovid@kovidgoyal.net>'
__docformat__ = 'restructuredtext en'
import unittest, shutil
from functools import partial
# Best-effort recursive directory removal: errors (e.g. the tree not
# existing) are silently ignored
rmtree = partial(shutil.rmtree, ignore_errors=True)
class BaseTest(unittest.TestCase):

    ''' Base class for the server test suites '''

    # Append the standard failure message to any custom assertion message
    longMessage = True
    # Never truncate diffs shown for failed assertions
    maxDiff = None

    # Short alias for assertEqual, used heavily by subclasses
    ae = unittest.TestCase.assertEqual

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import textwrap
from io import BytesIO
from calibre.srv.tests.base import BaseTest
def headers(raw):
    ''' Return a binary file-like object over the dedented, UTF-8 encoded raw
    header text, suitable for passing its readline to read_headers() '''
    text = textwrap.dedent(raw)
    return BytesIO(text.encode('utf-8'))
class TestHTTP(BaseTest):

    def test_header_parsing(self):
        'Test parsing of HTTP headers'
        from calibre.srv.http import read_headers

        def test(name, raw, **kwargs):
            # Expected headers are passed as identifiers: foo_bar -> Foo-Bar
            hdict = read_headers(headers(raw).readline)
            self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')

        test('Continuation line parsing',
             '''\
a: one\r
b: two\r
 2\r
\t3\r
c:three\r
\r\n''', a='one', b='two 2 3', c='three')

        test('Non-ascii headers parsing',
             '''\
a:mūs\r
\r\n''', a='mūs'.encode('utf-8'))

        # Bug fix: previously all five calls were inside a single
        # assertRaises block, so only the first ever executed — the rest were
        # dead code. Each malformed input gets its own context now.
        for bad in (
                'Connection:mūs\r\n',
                'Connection\r\n',
                'Connection:a\r\n',
                'Connection:a\n',
                ' Connection:a\n',
        ):
            with self.assertRaises(ValueError):
                read_headers(headers(bad).readline)

View File

@ -0,0 +1,101 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import unittest, os, argparse, time, functools
try:
import init_calibre
del init_calibre
except ImportError:
pass
def no_endl(f):
    ''' Wrap the bound TestResult method f so that, for the duration of the
    call, writeln on its result's stream behaves like write (i.e. no trailing
    newline is emitted). '''
    @functools.wraps(f)
    def func(*args, **kwargs):
        result = f.__self__
        saved_writeln = result.stream.writeln
        result.stream.writeln = result.stream.write
        try:
            return f(*args, **kwargs)
        finally:
            # Always restore the original writeln, even if f raises
            result.stream.writeln = saved_writeln
    return func
class TestResult(unittest.TextTestResult):

    ''' A TextTestResult that appends the wall-clock time taken by each test
    to its result line and, after a fully successful run, reports the three
    slowest tests. '''

    def __init__(self, *args, **kwargs):
        super(TestResult, self).__init__(*args, **kwargs)
        self.start_time = {}  # test -> time.time() at start
        # Wrap the outcome reporters so they do not end their output with a
        # newline; stopTest() then appends the elapsed time on the same line.
        for x in ('Success', 'Error', 'Failure', 'Skip', 'ExpectedFailure', 'UnexpectedSuccess'):
            x = 'add' + x
            setattr(self, x, no_endl(getattr(self, x)))
        self.times = {}  # test -> elapsed seconds

    def startTest(self, test):
        self.start_time[test] = time.time()
        return super(TestResult, self).startTest(test)

    def stopTest(self, test):
        # Suppress the newline the base class would emit, print the elapsed
        # time, then emit the newline ourselves.
        orig = self.stream.writeln
        self.stream.writeln = self.stream.write
        super(TestResult, self).stopTest(test)
        elapsed = time.time()
        # If startTest was never seen for this test, elapsed becomes 0
        elapsed -= self.start_time.get(test, elapsed)
        self.times[test] = elapsed
        self.stream.writeln = orig
        self.stream.writeln(' [%.1g s]' % elapsed)

    def stopTestRun(self):
        super(TestResult, self).stopTestRun()
        if self.wasSuccessful():
            # Report the slowest tests (at most three), slowest first
            tests = sorted(self.times, key=self.times.get, reverse=True)
            slowest = ['%s [%g s]' % (t.id(), self.times[t]) for t in tests[:3]]
            if len(slowest) > 1:
                self.stream.writeln('\nSlowest tests: %s' % ' '.join(slowest))
def find_tests():
    ''' Discover all tests in the directory that contains this file '''
    here = os.path.dirname(os.path.abspath(__file__))
    loader = unittest.defaultTestLoader
    return loader.discover(here, pattern='*.py')
def run_tests(find_tests=find_tests):
    ''' Parse the command line and run matching tests.

    The optional positional argument is either a name suitable for
    unittest's loadTestsFromName() or, as a shortcut, a test method name
    prefixed with a period (e.g. ".many_many_basic"), in which case all
    discovered suites are searched for a method of that name.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('name', nargs='?', default=None,
        help='The name of the test to run, for e.g. writing.WritingTest.many_many_basic or .many_many_basic for a shortcut')
    args = parser.parse_args()
    if args.name and args.name.startswith('.'):
        tests = find_tests()
        q = args.name[1:]
        if not q.startswith('test_'):
            # Test methods conventionally start with test_
            q = 'test_' + q
        ans = None
        try:
            # Walk suites -> test case classes -> individual tests, using
            # StopIteration to break out of the nested loops on first match.
            for suite in tests:
                for test in suite._tests:
                    if test.__class__.__name__ == 'ModuleImportFailure':
                        raise Exception('Failed to import a test module: %s' % test)
                    for s in test:
                        if s._testMethodName == q:
                            ans = s
                            raise StopIteration()
        except StopIteration:
            pass
        if ans is None:
            print ('No test named %s found' % args.name)
            raise SystemExit(1)
        tests = ans
    else:
        tests = unittest.defaultTestLoader.loadTestsFromName(args.name) if args.name else find_tests()
    r = unittest.TextTestRunner
    r.resultclass = TestResult
    r(verbosity=4).run(tests)
if __name__ == '__main__':
    # Reset global calibre state so test runs are not affected by the user's
    # tweaks or previously modified field metadata
    from calibre.utils.config_base import reset_tweaks_to_default
    from calibre.ebooks.metadata.book.base import reset_field_metadata
    reset_tweaks_to_default()
    reset_field_metadata()
    run_tests()

69
src/calibre/srv/utils.py Normal file
View File

@ -0,0 +1,69 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
from urlparse import parse_qs
class MultiDict(dict):  # {{{

    ''' A dict that can store multiple values per key. Normal item access
    returns the most recently set value; the all=True variants of get() and
    pop() return the full list of values. '''

    def __setitem__(self, key, val):
        # Append to the list of values for key rather than replacing it
        vals = dict.get(self, key, [])
        vals.append(val)
        dict.__setitem__(self, key, vals)

    def __getitem__(self, key):
        # The most recently added value wins
        return dict.__getitem__(self, key)[-1]

    @staticmethod
    def create_from_query_string(qs):
        ''' Create a MultiDict from a raw query string, decoding keys and
        values as UTF-8 '''
        ans = MultiDict()
        # Bug fix: parse_qs() returns a mapping of key -> list of values;
        # iterate its items — iterating the mapping directly yields only keys
        # and the 2-tuple unpacking fails. items() works on both py2 and py3.
        for k, v in parse_qs(qs, keep_blank_values=True).items():
            dict.__setitem__(ans, k.decode('utf-8'), [x.decode('utf-8') for x in v])
        return ans

    def update_from_listdict(self, ld):
        ''' Merge in a mapping of key -> list of values '''
        for key, values in ld.items():
            for val in values:
                self[key] = val

    def items(self, duplicates=True):
        ''' Iterate over (key, value) pairs. By default every stored value is
        yielded; with duplicates=False only the most recent value per key. '''
        for k, v in dict.items(self):
            if duplicates:
                for x in v:
                    yield k, x
            else:
                yield k, v[-1]
    iteritems = items

    def values(self, duplicates=True):
        ''' Iterate over values, analogously to items() '''
        for v in dict.values(self):
            if duplicates:
                for x in v:
                    yield x
            else:
                yield v[-1]
    itervalues = values

    def set(self, key, val, replace=False):
        ''' Set a value, optionally discarding all previously stored values
        for the key '''
        if replace:
            dict.__setitem__(self, key, [val])
        else:
            self[key] = val

    def get(self, key, default=None, all=False):
        ''' Like dict.get(). With all=True, return the list of every stored
        value for key (an empty list if the key is missing). '''
        if all:
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                return []
        # Bug fix: honor default for missing keys instead of letting the
        # KeyError from __getitem__ escape (dict.get() must never raise)
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def pop(self, key, default=None, all=False):
        ''' Remove key and return its most recent value (or all values with
        all=True); default / [] when the key is missing '''
        ans = dict.pop(self, key, default)
        if ans is default:
            return [] if all else default
        return ans if all else ans[-1]