Finish up HTTP header parsing and add some tests

This commit is contained in:
Kovid Goyal 2015-05-18 10:01:36 +05:30
parent 18a04ed23d
commit 48f548236e
6 changed files with 383 additions and 52 deletions

View File

@ -6,13 +6,15 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import httplib, socket, re
import httplib, socket, re, os
from io import BytesIO
import repr as reprlib
from urllib import unquote
from urlparse import parse_qs
from functools import partial
from calibre import as_unicode
from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest
from calibre.srv.utils import MultiDict
HTTP1 = 'HTTP/1.0'
HTTP11 = 'HTTP/1.1'
@ -62,19 +64,19 @@ def parse_request_uri(uri): # {{{
# }}}
comma_separated_headers = {
b'Accept', b'Accept-Charset', b'Accept-Encoding',
b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control',
b'Connection', b'Content-Encoding', b'Content-Language', b'Expect',
b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE',
b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning',
b'WWW-Authenticate'
'Accept', 'Accept-Charset', 'Accept-Encoding',
'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
'WWW-Authenticate'
}
decoded_headers = {
'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
}
} | comma_separated_headers
def read_headers(readline, max_line_size, hdict=None): # {{{
def read_headers(readline): # {{{
"""
Read headers from the given stream into the given header dict.
@ -87,8 +89,27 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
This function raises ValueError when the read bytes violate the HTTP spec.
You should probably return "400 Bad Request" if this happens.
"""
if hdict is None:
hdict = {}
hdict = MultiDict()
def safe_decode(hname, value):
try:
return value.decode('ascii')
except UnicodeDecodeError:
if hname in decoded_headers:
raise
return value
current_key = current_value = None
def commit():
if current_key:
key = current_key.decode('ascii')
val = safe_decode(key, current_value)
if key in comma_separated_headers:
existing = hdict.pop(key)
if existing is not None:
val = existing + ', ' + val
hdict[key] = val
while True:
line = readline()
@ -98,32 +119,22 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
if line == b'\r\n':
# Normal end of headers
commit()
break
if not line.endswith(b'\r\n'):
raise ValueError("HTTP requires CRLF terminators")
if line[0] in (b' ', b'\t'):
if line[0] in b' \t':
# It's a continuation line.
v = line.strip()
if current_key is None or current_value is None:
raise ValueError('Orphaned continuation line')
current_value += b' ' + line.strip()
else:
try:
k, v = line.split(b':', 1)
except ValueError:
raise ValueError("Illegal header line.")
k = k.strip().title()
v = v.strip()
hname = k.decode('ascii')
if k in comma_separated_headers:
existing = hdict.get(hname)
if existing:
v = b", ".join((existing, v))
try:
v = v.decode('ascii')
except UnicodeDecodeError:
if hname in decoded_headers:
raise
hdict[hname] = v
commit()
current_key = current_value = None
k, v = line.split(b':', 1)
current_key = k.strip().title()
current_value = v.strip()
return hdict
# }}}
@ -133,53 +144,130 @@ def http_communicate(conn):
request_seen = False
try:
while True:
# (re)set req to None so that if something goes wrong in
# (re)set pair to None so that if something goes wrong in
# the HTTPPair constructor, the error doesn't
# get written to the previous request.
req = None
req = conn.server_loop.http_handler(conn)
pair = None
pair = conn.server_loop.http_handler(conn)
# This order of operations should guarantee correct pipelining.
req.parse_request()
if not req.ready:
pair.parse_request()
if not pair.ready:
# Something went wrong in the parsing (and the server has
# probably already made a simple_response). Return and
# let the conn close.
return
request_seen = True
req.respond()
if req.close_connection:
pair.respond()
if pair.close_connection:
return
except socket.timeout:
# Don't error if we're between requests; only error
# if 1) no request has been started at all, or 2) we're
# in the middle of a request. This allows persistent
# connections for HTTP/1.1
if (not request_seen) or (req and req.started_request):
if (not request_seen) or (pair and pair.started_request):
# Don't bother writing the 408 if the response
# has already started being written.
if req and not req.sent_headers:
req.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
if pair and not pair.sent_headers:
pair.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
except NonHTTPConnRequest:
raise
except Exception:
conn.server_loop.log.exception()
if req and not req.sent_headers:
req.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
conn.server_loop.log.exception('Error serving request:', pair.path if pair else None)
if pair and not pair.sent_headers:
pair.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
class FixedSizeReader(object):
    ''' Callable that hands out at most ``content_length`` bytes, in total,
    from a file-like object.

    Each call reads from ``socket_file`` and lowers ``self.remaining`` by the
    number of bytes actually returned. '''

    def __init__(self, socket_file, content_length):
        self.socket_file = socket_file
        self.remaining = content_length

    def __call__(self, size=-1):
        # A negative size asks for everything that is left; any other request
        # is capped at what is left.
        wanted = self.remaining if size < 0 else min(self.remaining, size)
        if wanted < 1:
            return b''
        data = self.socket_file.read(wanted)
        self.remaining -= len(data)
        return data
class ChunkedReader(object):
    ''' Callable that decodes a chunked request body read from ``socket_file``.

    Decoded payload bytes are accumulated in ``self.rbuf``. ``self.bytes_read``
    counts every byte consumed from ``socket_file`` (size lines, payload and
    chunk terminators) and is checked against ``maxsize``. Trailers are not
    supported: the zero-size chunk must be followed directly by CRLF.

    Calling the instance with a negative size returns all remaining data,
    with ``0`` returns ``b''`` and with a positive size returns at most that
    many bytes.

    Raises ValueError for malformed chunked encoding and MaxSizeExceeded when
    the body would exceed ``maxsize``. '''

    def __init__(self, socket_file, maxsize):
        self.socket_file, self.maxsize = socket_file, maxsize
        self.rbuf = BytesIO()
        self.bytes_read = 0
        self.finished = False

    def check_size(self):
        if self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize)

    def _reset_buffer(self, data=b''):
        # truncate() does not move the stream position, so rewind first;
        # otherwise the write would land at the stale offset and the gap
        # would be filled with NUL bytes.
        self.rbuf.seek(0)
        self.rbuf.truncate(0)
        self.rbuf.write(data)

    def read_chunk(self):
        ''' Read one chunk from the socket and append its payload to the
        buffer. Does nothing once the terminating chunk has been seen. '''
        if self.finished:
            return
        line = self.socket_file.readline()
        self.bytes_read += len(line)
        self.check_size()
        # Anything after the first ";" is a chunk extension, which is ignored
        size_field = line.strip().split(b';', 1)[0]
        try:
            chunk_size = int(size_field, 16)
        except Exception:
            chunk_size = -1
        if chunk_size < 0:
            # Also covers an explicitly negative size, which would otherwise
            # be passed to read() as a negative length
            raise ValueError('%s is not a valid chunk size' % reprlib.repr(size_field))
        chunk_size += 2  # the payload is followed by CRLF
        if chunk_size + self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize)
        chunk = self.socket_file.read(chunk_size)
        if len(chunk) < chunk_size:
            raise ValueError('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size))
        if not chunk.endswith(b'\r\n'):
            raise ValueError('Bad chunked encoding: %r != CRLF' % (chunk[-2:],))
        self.rbuf.seek(0, os.SEEK_END)
        self.bytes_read += chunk_size
        if chunk_size == 2:
            # Zero length chunk: end of the body
            self.finished = True
        else:
            self.rbuf.write(chunk[:-2])

    def __call__(self, size=-1):
        if size < 0:
            # Read all data
            while not self.finished:
                self.read_chunk()
            data = self.rbuf.getvalue()
            self._reset_buffer()
            return data
        if size == 0:
            return b''
        # With the position at the end, tell() is the number of buffered bytes
        self.rbuf.seek(0, os.SEEK_END)
        while self.rbuf.tell() < size and not self.finished:
            self.read_chunk()
        data = self.rbuf.getvalue()
        # Keep whatever was not asked for around for the next call
        self._reset_buffer(data[size:])
        return data[:size]
class HTTPPair(object):
''' Represents a HTTP request/response pair '''
def __init__(self, conn):
def __init__(self, conn, handle_request):
self.conn = conn
self.server_loop = conn.server_loop
self.max_header_line_size = self.server_loop.max_header_line_size
self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
self.inheaders = {}
self.outheaders = []
self.inheaders = MultiDict()
self.outheaders = MultiDict()
self.handle_request = handle_request
self.path = ()
self.qs = MultiDict()
"""When True, the request has been parsed and is ready to begin generating
the response. When False, signals the calling Connection that the response
@ -198,6 +286,9 @@ class HTTPPair(object):
self.status = b''
self.sent_headers = False
self.request_content_length = 0
self.chunked_read = False
def parse_request(self):
"""Parse the next HTTP request start-line and message-headers."""
try:
@ -273,7 +364,7 @@ class HTTPPair(object):
if b'?' in path:
path, qs = path.split(b'?', 1)
try:
self.qs = {k.decode('utf-8'):tuple(x.decode('utf-8') for x in v) for k, v in parse_qs(qs, keep_blank_values=True).iteritems()}
self.qs = MultiDict.create_from_query_string(qs)
except Exception:
self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line",
'Unparseable query string')
@ -293,13 +384,13 @@ class HTTPPair(object):
def read_request_headers(self):
# then all the http headers
try:
read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size), self.inheaders)
content_length = int(self.inheaders.get('Content-Length', 0))
self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size))
self.request_content_length = int(self.inheaders.get('Content-Length', 0))
except ValueError as e:
self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e))
return False
if content_length > self.server_loop.max_request_body_size:
if self.request_content_length > self.server_loop.max_request_body_size:
self.simple_response(
httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large",
"The entity sent with the request exceeds the maximum "

View File

@ -0,0 +1 @@

View File

@ -0,0 +1,22 @@
#!/usr/bin/env python2
# vim:fileencoding=UTF-8:ts=4:sw=4:sta:et:sts=4:ai
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2011, Kovid Goyal <kovid@kovidgoyal.net>'
__docformat__ = 'restructuredtext en'
import unittest, shutil
from functools import partial
rmtree = partial(shutil.rmtree, ignore_errors=True)
class BaseTest(unittest.TestCase):
    ''' Base class for the test cases, configuring failure reporting '''

    # Never truncate the diffs shown for failed comparisons
    maxDiff = None
    # Append custom failure messages to the standard message instead of
    # replacing it
    longMessage = True
    # Short alias for assertEqual
    ae = unittest.TestCase.assertEqual

View File

@ -0,0 +1,47 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import textwrap
from io import BytesIO
from calibre.srv.tests.base import BaseTest
def headers(raw):
    ''' Return a binary stream containing ``raw`` with its common leading
    whitespace removed and encoded as UTF-8 '''
    dedented = textwrap.dedent(raw)
    return BytesIO(dedented.encode('utf-8'))
class TestHTTP(BaseTest):

    def test_header_parsing(self):
        'Test parsing of HTTP headers'
        from calibre.srv.http import read_headers

        def test(name, raw, **kwargs):
            # Keyword names are mapped to header names: a_b -> A-B
            hdict = read_headers(headers(raw).readline)
            self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')

        # The leading space/tab on the continuation lines is significant, so
        # the raw data is spelled out with explicit escapes rather than
        # relying on source indentation.
        test('Continuation line parsing',
             'a: one\r\n'
             'b: two\r\n'
             ' 2\r\n'
             '\t3\r\n'
             'c:three\r\n'
             '\r\n', a='one', b='two 2 3', c='three')

        test('Non-ascii headers parsing',
             'a:mūs\r\n'
             '\r\n', a='mūs'.encode('utf-8'))

        # Every malformed input gets its own assertRaises context: inside a
        # single shared context, execution stops at the first exception and
        # the remaining inputs are never checked.
        for raw in (
            'Connection:mūs\r\n',
            'Connection\r\n',
            'Connection:a\r\n',
            'Connection:a\n',
            ' Connection:a\n',
        ):
            with self.assertRaises(ValueError):
                read_headers(headers(raw).readline)

View File

@ -0,0 +1,101 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import unittest, os, argparse, time, functools
try:
import init_calibre
del init_calibre
except ImportError:
pass
def no_endl(f):
    ''' Wrap the bound method ``f`` so that, while it runs, the ``writeln``
    of its instance's ``stream`` is replaced by the stream's ``write``. The
    original ``writeln`` is put back afterwards, even if ``f`` raises. '''
    @functools.wraps(f)
    def func(*args, **kwargs):
        owner = f.__self__
        saved_writeln = owner.stream.writeln
        owner.stream.writeln = owner.stream.write
        try:
            result = f(*args, **kwargs)
        finally:
            owner.stream.writeln = saved_writeln
        return result
    return func
class TestResult(unittest.TextTestResult):
    ''' Text test result that also reports how long every test took.

    The elapsed time is written on the same line as the result of the test
    and, after a successful run, the three slowest tests are listed. '''

    def __init__(self, *args, **kwargs):
        super(TestResult, self).__init__(*args, **kwargs)
        self.start_time = {}
        # Stop the add* methods from ending the line, so that stopTest() can
        # append the elapsed time to it
        for x in ('Success', 'Error', 'Failure', 'Skip', 'ExpectedFailure', 'UnexpectedSuccess'):
            x = 'add' + x
            setattr(self, x, no_endl(getattr(self, x)))
        self.times = {}

    def startTest(self, test):
        self.start_time[test] = time.time()
        return super(TestResult, self).startTest(test)

    def stopTest(self, test):
        orig = self.stream.writeln
        self.stream.writeln = self.stream.write
        try:
            super(TestResult, self).stopTest(test)
        finally:
            # Always restore writeln, otherwise an error in stopTest() would
            # leave the stream without line endings for the rest of the run
            self.stream.writeln = orig
        elapsed = time.time()
        # A test that was never started is reported as taking no time
        elapsed -= self.start_time.get(test, elapsed)
        self.times[test] = elapsed
        self.stream.writeln(' [%.1g s]' % elapsed)

    def stopTestRun(self):
        super(TestResult, self).stopTestRun()
        if self.wasSuccessful():
            tests = sorted(self.times, key=self.times.get, reverse=True)
            slowest = ['%s [%g s]' % (t.id(), self.times[t]) for t in tests[:3]]
            # Not worth reporting when only a single test was run
            if len(slowest) > 1:
                self.stream.writeln('\nSlowest tests: %s' % ' '.join(slowest))
def find_tests():
    ''' Discover the tests in every ``*.py`` file found in the directory
    that contains this module '''
    here = os.path.dirname(os.path.abspath(__file__))
    loader = unittest.defaultTestLoader
    return loader.discover(here, pattern='*.py')
def _find_test_by_method_name(tests, method_name):
    ''' Return the first test case in ``tests`` whose test method is named
    ``method_name``, or None if there is no such test. Raises an Exception
    if one of the test modules could not be imported. '''
    for suite in tests:
        for test in suite._tests:
            if test.__class__.__name__ == 'ModuleImportFailure':
                raise Exception('Failed to import a test module: %s' % test)
            for s in test:
                if s._testMethodName == method_name:
                    return s
    return None


def run_tests(find_tests=find_tests):
    ''' Run the tests selected by the optional command line argument.

    With no argument, everything returned by ``find_tests()`` is run. An
    argument starting with a dot selects a single test by method name (the
    ``test_`` prefix may be omitted), any other argument is loaded as a
    dotted test name. Exits with status 1 if a dotted shortcut matches
    nothing. '''
    parser = argparse.ArgumentParser()
    parser.add_argument('name', nargs='?', default=None,
                        help='The name of the test to run, for e.g. writing.WritingTest.many_many_basic or .many_many_basic for a shortcut')
    args = parser.parse_args()
    if args.name and args.name.startswith('.'):
        q = args.name[1:]
        if not q.startswith('test_'):
            q = 'test_' + q
        tests = _find_test_by_method_name(find_tests(), q)
        if tests is None:
            print ('No test named %s found' % args.name)
            raise SystemExit(1)
    else:
        tests = unittest.defaultTestLoader.loadTestsFromName(args.name) if args.name else find_tests()
    # Pass the result class to this runner only, rather than assigning it on
    # unittest.TextTestRunner, which would change every runner in the process
    unittest.TextTestRunner(verbosity=4, resultclass=TestResult).run(tests)
if __name__ == '__main__':
    # Imported here so that merely importing this module does not pull in
    # these calibre modules
    from calibre.utils.config_base import reset_tweaks_to_default
    from calibre.ebooks.metadata.book.base import reset_field_metadata
    # NOTE(review): presumably resets state so the tests run against default
    # settings -- confirm against the two helpers
    reset_tweaks_to_default()
    reset_field_metadata()
    run_tests()

69
src/calibre/srv/utils.py Normal file
View File

@ -0,0 +1,69 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
from urlparse import parse_qs
class MultiDict(dict):
    ''' A dict that can store several values for the same key.

    Internally every key maps to the list of values assigned to it, in
    order of assignment. Item assignment appends to that list, item lookup
    returns the most recently assigned value. '''

    def __setitem__(self, key, val):
        vals = dict.get(self, key, [])
        vals.append(val)
        dict.__setitem__(self, key, vals)

    def __getitem__(self, key):
        return dict.__getitem__(self, key)[-1]

    @staticmethod
    def create_from_query_string(qs):
        ''' Create a MultiDict from the query string ``qs`` (bytes); keys and
        values are decoded as UTF-8. Blank values are preserved. '''
        ans = MultiDict()
        # parse_qs() returns a dict mapping each name to its list of values
        for k, v in parse_qs(qs, keep_blank_values=True).items():
            dict.__setitem__(ans, k.decode('utf-8'), [x.decode('utf-8') for x in v])
        return ans

    def update_from_listdict(self, ld):
        ''' Add all the values from ``ld``, a mapping of keys to iterables
        of values '''
        for key, values in ld.items():
            for val in values:
                self[key] = val

    def items(self, duplicates=True):
        ''' Iterate over (key, value) pairs. With ``duplicates`` every
        stored value is yielded, otherwise only the last value per key. '''
        for k, v in dict.items(self):
            if duplicates:
                for x in v:
                    yield k, x
            else:
                yield k, v[-1]
    iteritems = items

    def values(self, duplicates=True):
        ''' Iterate over values. With ``duplicates`` every stored value is
        yielded, otherwise only the last value per key. '''
        for v in dict.values(self):
            if duplicates:
                for x in v:
                    yield x
            else:
                yield v[-1]
    itervalues = values

    def set(self, key, val, replace=False):
        ''' Add ``val`` for ``key``; with ``replace`` all values previously
        stored for ``key`` are discarded. '''
        if replace:
            dict.__setitem__(self, key, [val])
        else:
            self[key] = val

    def get(self, key, default=None, all=False):
        ''' Return the last value for ``key``, or ``default`` if ``key`` is
        missing. With ``all`` return the list of all values instead (an
        empty list if ``key`` is missing). '''
        if all:
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                return []
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def pop(self, key, default=None, all=False):
        ''' Remove ``key`` and return its last value, or ``default`` if
        ``key`` is missing. With ``all`` return the list of all values
        instead (an empty list if ``key`` is missing). '''
        try:
            vals = dict.pop(self, key)
        except KeyError:
            return [] if all else default
        return vals if all else vals[-1]