mirror of https://github.com/kovidgoyal/calibre.git
Finish up HTTP header parsing and add some tests
This commit is contained in:
parent 18a04ed23d
commit 48f548236e
@@ -6,13 +6,15 @@ from __future__ import (unicode_literals, division, absolute_import,
 __license__ = 'GPL v3'
 __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'

-import httplib, socket, re
+import httplib, socket, re, os
+from io import BytesIO
+import repr as reprlib
 from urllib import unquote
-from urlparse import parse_qs
 from functools import partial

 from calibre import as_unicode
 from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest
+from calibre.srv.utils import MultiDict

 HTTP1 = 'HTTP/1.0'
 HTTP11 = 'HTTP/1.1'
@@ -62,19 +64,19 @@ def parse_request_uri(uri): # {{{
 # }}}

 comma_separated_headers = {
-    b'Accept', b'Accept-Charset', b'Accept-Encoding',
-    b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control',
-    b'Connection', b'Content-Encoding', b'Content-Language', b'Expect',
-    b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE',
-    b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning',
-    b'WWW-Authenticate'
+    'Accept', 'Accept-Charset', 'Accept-Encoding',
+    'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
+    'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
+    'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
+    'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
+    'WWW-Authenticate'
 }

 decoded_headers = {
     'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
-}
+} | comma_separated_headers

-def read_headers(readline, max_line_size, hdict=None):  # {{{
+def read_headers(readline):  # {{{
     """
     Read headers from the given stream into the given header dict.

@@ -87,8 +89,27 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
     This function raises ValueError when the read bytes violate the HTTP spec.
     You should probably return "400 Bad Request" if this happens.
     """
-    if hdict is None:
-        hdict = {}
+    hdict = MultiDict()

+    def safe_decode(hname, value):
+        try:
+            return value.decode('ascii')
+        except UnicodeDecodeError:
+            if hname in decoded_headers:
+                raise
+        return value
+
+    current_key = current_value = None
+
+    def commit():
+        if current_key:
+            key = current_key.decode('ascii')
+            val = safe_decode(key, current_value)
+            if key in comma_separated_headers:
+                existing = hdict.pop(key)
+                if existing is not None:
+                    val = existing + ', ' + val
+            hdict[key] = val
+
     while True:
         line = readline()
@@ -98,32 +119,22 @@ def read_headers(readline, max_line_size, hdict=None): # {{{

         if line == b'\r\n':
             # Normal end of headers
+            commit()
             break
         if not line.endswith(b'\r\n'):
             raise ValueError("HTTP requires CRLF terminators")

-        if line[0] in (b' ', b'\t'):
+        if line[0] in b' \t':
             # It's a continuation line.
-            v = line.strip()
+            if current_key is None or current_value is None:
+                raise ValueError('Orphaned continuation line')
+            current_value += b' ' + line.strip()
         else:
-            try:
-                k, v = line.split(b':', 1)
-            except ValueError:
-                raise ValueError("Illegal header line.")
-            k = k.strip().title()
-            v = v.strip()
-            hname = k.decode('ascii')
-
-        if k in comma_separated_headers:
-            existing = hdict.get(hname)
-            if existing:
-                v = b", ".join((existing, v))
-        try:
-            v = v.decode('ascii')
-        except UnicodeDecodeError:
-            if hname in decoded_headers:
-                raise
-        hdict[hname] = v
+            commit()
+            current_key = current_value = None
+            k, v = line.split(b':', 1)
+            current_key = k.strip().title()
+            current_value = v.strip()

     return hdict
 # }}}
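The reworked read_headers() above returns a MultiDict and, via the commit() helper, folds repeated comma-separated headers into a single value. A minimal sketch of driving it directly, in the same way the new tests below do (a BytesIO stands in for the connection's socket file; the request bytes are invented for illustration):

    from io import BytesIO
    from calibre.srv.http import read_headers

    # Hypothetical header block; in the server these bytes arrive on conn.socket_file
    raw = b'Host: example.com\r\nAccept: text/html\r\nAccept: text/plain\r\n\r\n'
    hdict = read_headers(BytesIO(raw).readline)
    print(hdict['Host'])    # example.com
    print(hdict['Accept'])  # text/html, text/plain -- folded, since Accept is comma-separated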
@@ -133,53 +144,130 @@ def http_communicate(conn):
     request_seen = False
     try:
         while True:
-            # (re)set req to None so that if something goes wrong in
+            # (re)set pair to None so that if something goes wrong in
             # the HTTPPair constructor, the error doesn't
             # get written to the previous request.
-            req = None
-            req = conn.server_loop.http_handler(conn)
+            pair = None
+            pair = conn.server_loop.http_handler(conn)

             # This order of operations should guarantee correct pipelining.
-            req.parse_request()
-            if not req.ready:
+            pair.parse_request()
+            if not pair.ready:
                 # Something went wrong in the parsing (and the server has
                 # probably already made a simple_response). Return and
                 # let the conn close.
                 return

             request_seen = True
-            req.respond()
-            if req.close_connection:
+            pair.respond()
+            if pair.close_connection:
                 return
     except socket.timeout:
         # Don't error if we're between requests; only error
         # if 1) no request has been started at all, or 2) we're
         # in the middle of a request. This allows persistent
         # connections for HTTP/1.1
-        if (not request_seen) or (req and req.started_request):
+        if (not request_seen) or (pair and pair.started_request):
             # Don't bother writing the 408 if the response
             # has already started being written.
-            if req and not req.sent_headers:
-                req.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
+            if pair and not pair.sent_headers:
+                pair.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
     except NonHTTPConnRequest:
         raise
     except Exception:
-        conn.server_loop.log.exception()
-        if req and not req.sent_headers:
-            req.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
+        conn.server_loop.log.exception('Error serving request:', pair.path if pair else None)
+        if pair and not pair.sent_headers:
+            pair.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
+
+class FixedSizeReader(object):
+
+    def __init__(self, socket_file, content_length):
+        self.socket_file, self.remaining = socket_file, content_length
+
+    def __call__(self, size=-1):
+        if size < 0:
+            size = self.remaining
+        size = min(self.remaining, size)
+        if size < 1:
+            return b''
+        data = self.socket_file.read(size)
+        self.remaining -= len(data)
+        return data
+
+
+class ChunkedReader(object):
+
+    def __init__(self, socket_file, maxsize):
+        self.socket_file, self.maxsize = socket_file, maxsize
+        self.rbuf = BytesIO()
+        self.bytes_read = 0
+        self.finished = False
+
+    def check_size(self):
+        if self.bytes_read > self.maxsize:
+            raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize)
+
+    def read_chunk(self):
+        if self.finished:
+            return
+        line = self.socket_file.readline()
+        self.bytes_read += len(line)
+        self.check_size()
+        chunk_size = line.strip().split(b';', 1)[0]
+        try:
+            chunk_size = int(line, 16) + 2
+        except Exception:
+            raise ValueError('%s is not a valid chunk size' % reprlib.repr(chunk_size))
+        if chunk_size + self.bytes_read > self.maxsize:
+            raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize)
+        chunk = self.socket_file.read(chunk_size)
+        if len(chunk) < chunk_size:
+            raise ValueError('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size))
+        if not chunk.endswith(b'\r\n'):
+            raise ValueError('Bad chunked encoding: %r != CRLF' % chunk[:-2])
+        self.rbuf.seek(0, os.SEEK_END)
+        self.bytes_read += chunk_size
+        if chunk_size == 2:
+            self.finished = True
+        else:
+            self.rbuf.write(chunk[:-2])
+
+    def __call__(self, size=-1):
+        if size < 0:
+            # Read all data
+            while not self.finished:
+                self.read_chunk()
+            self.rbuf.seek(0)
+            rv = self.rbuf.read()
+            if rv:
+                self.rbuf.truncate(0)
+            return rv
+        if size == 0:
+            return b''
+        while self.rbuf.tell() < size and not self.finished:
+            self.read_chunk()
+        data = self.rbuf.getvalue()
+        self.rbuf.truncate(0)
+        if size < len(data):
+            self.rbuf.write(data[size:])
+            return data[:size]
+        return data


 class HTTPPair(object):

     ''' Represents a HTTP request/response pair '''

-    def __init__(self, conn):
+    def __init__(self, conn, handle_request):
         self.conn = conn
         self.server_loop = conn.server_loop
         self.max_header_line_size = self.server_loop.max_header_line_size
         self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
-        self.inheaders = {}
-        self.outheaders = []
+        self.inheaders = MultiDict()
+        self.outheaders = MultiDict()
+        self.handle_request = handle_request
+        self.path = ()
+        self.qs = MultiDict()

         """When True, the request has been parsed and is ready to begin generating
         the response. When False, signals the calling Connection that the response
@@ -198,6 +286,9 @@ class HTTPPair(object):
         self.status = b''
         self.sent_headers = False

+        self.request_content_length = 0
+        self.chunked_read = False
+
     def parse_request(self):
         """Parse the next HTTP request start-line and message-headers."""
         try:
@@ -273,7 +364,7 @@ class HTTPPair(object):
         if b'?' in path:
             path, qs = path.split(b'?', 1)
             try:
-                self.qs = {k.decode('utf-8'):tuple(x.decode('utf-8') for x in v) for k, v in parse_qs(qs, keep_blank_values=True).iteritems()}
+                self.qs = MultiDict.create_from_query_string(qs)
             except Exception:
                 self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line",
                                      'Unparseable query string')
@@ -293,13 +384,13 @@ class HTTPPair(object):
     def read_request_headers(self):
         # then all the http headers
         try:
-            read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size), self.inheaders)
-            content_length = int(self.inheaders.get('Content-Length', 0))
+            self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size))
+            self.request_content_length = int(self.inheaders.get('Content-Length', 0))
         except ValueError as e:
             self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e))
             return False

-        if content_length > self.server_loop.max_request_body_size:
+        if self.request_content_length > self.server_loop.max_request_body_size:
             self.simple_response(
                 httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large",
                 "The entity sent with the request exceeds the maximum "
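The two reader classes added above handle request bodies: FixedSizeReader hands back at most Content-Length bytes from the socket file, while ChunkedReader decodes a Transfer-Encoding: chunked body and enforces the maximum request size via MaxSizeExceeded. A rough sketch of their behaviour, assuming a BytesIO can stand in for the connection's socket file (the byte strings are invented for illustration):

    from io import BytesIO
    from calibre.srv.http import FixedSizeReader, ChunkedReader

    fixed = FixedSizeReader(BytesIO(b'hello world -- trailing bytes are ignored'), 11)
    print(fixed())    # hello world
    print(fixed())    # prints an empty line: the body is exhausted

    chunked = ChunkedReader(BytesIO(b'5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n'), maxsize=1024)
    print(chunked())  # hello world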
src/calibre/srv/tests/__init__.py (new file)
@@ -0,0 +1 @@
src/calibre/srv/tests/base.py (new file)
@@ -0,0 +1,22 @@
#!/usr/bin/env python2
# vim:fileencoding=UTF-8:ts=4:sw=4:sta:et:sts=4:ai
from __future__ import (unicode_literals, division, absolute_import,
                        print_function)

__license__ = 'GPL v3'
__copyright__ = '2011, Kovid Goyal <kovid@kovidgoyal.net>'
__docformat__ = 'restructuredtext en'

import unittest, shutil
from functools import partial

rmtree = partial(shutil.rmtree, ignore_errors=True)

class BaseTest(unittest.TestCase):

    longMessage = True
    maxDiff = None

    ae = unittest.TestCase.assertEqual
src/calibre/srv/tests/http.py (new file)
@@ -0,0 +1,47 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
                        print_function)

__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'

import textwrap
from io import BytesIO

from calibre.srv.tests.base import BaseTest

def headers(raw):
    return BytesIO(textwrap.dedent(raw).encode('utf-8'))

class TestHTTP(BaseTest):

    def test_header_parsing(self):
        'Test parsing of HTTP headers'
        from calibre.srv.http import read_headers

        def test(name, raw, **kwargs):
            hdict = read_headers(headers(raw).readline)
            self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')

        test('Continuation line parsing',
             '''\
             a: one\r
             b: two\r
              2\r
             \t3\r
             c:three\r
             \r\n''', a='one', b='two 2 3', c='three')

        test('Non-ascii headers parsing',
             '''\
             a:mūs\r
             \r\n''', a='mūs'.encode('utf-8'))

        with self.assertRaises(ValueError):
            read_headers(headers('Connection:mūs\r\n').readline)
            read_headers(headers('Connection\r\n').readline)
            read_headers(headers('Connection:a\r\n').readline)
            read_headers(headers('Connection:a\n').readline)
            read_headers(headers(' Connection:a\n').readline)
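The test module above is discovered by the runner added in src/calibre/srv/tests/main.py below, but it can also be driven with the stock unittest machinery. A small sketch, assuming a calibre development environment where the calibre.srv package is importable:

    import unittest
    from calibre.srv.tests.http import TestHTTP

    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestHTTP)
    unittest.TextTestRunner(verbosity=2).run(suite)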
src/calibre/srv/tests/main.py (new file)
@@ -0,0 +1,101 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
                        print_function)

__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'

import unittest, os, argparse, time, functools

try:
    import init_calibre
    del init_calibre
except ImportError:
    pass

def no_endl(f):
    @functools.wraps(f)
    def func(*args, **kwargs):
        self = f.__self__
        orig = self.stream.writeln
        self.stream.writeln = self.stream.write
        try:
            return f(*args, **kwargs)
        finally:
            self.stream.writeln = orig
    return func

class TestResult(unittest.TextTestResult):

    def __init__(self, *args, **kwargs):
        super(TestResult, self).__init__(*args, **kwargs)
        self.start_time = {}
        for x in ('Success', 'Error', 'Failure', 'Skip', 'ExpectedFailure', 'UnexpectedSuccess'):
            x = 'add' + x
            setattr(self, x, no_endl(getattr(self, x)))
        self.times = {}

    def startTest(self, test):
        self.start_time[test] = time.time()
        return super(TestResult, self).startTest(test)

    def stopTest(self, test):
        orig = self.stream.writeln
        self.stream.writeln = self.stream.write
        super(TestResult, self).stopTest(test)
        elapsed = time.time()
        elapsed -= self.start_time.get(test, elapsed)
        self.times[test] = elapsed
        self.stream.writeln = orig
        self.stream.writeln(' [%.1g s]' % elapsed)

    def stopTestRun(self):
        super(TestResult, self).stopTestRun()
        if self.wasSuccessful():
            tests = sorted(self.times, key=self.times.get, reverse=True)
            slowest = ['%s [%g s]' % (t.id(), self.times[t]) for t in tests[:3]]
            if len(slowest) > 1:
                self.stream.writeln('\nSlowest tests: %s' % ' '.join(slowest))

def find_tests():
    return unittest.defaultTestLoader.discover(os.path.dirname(os.path.abspath(__file__)), pattern='*.py')

def run_tests(find_tests=find_tests):
    parser = argparse.ArgumentParser()
    parser.add_argument('name', nargs='?', default=None,
                        help='The name of the test to run, for e.g. writing.WritingTest.many_many_basic or .many_many_basic for a shortcut')
    args = parser.parse_args()
    if args.name and args.name.startswith('.'):
        tests = find_tests()
        q = args.name[1:]
        if not q.startswith('test_'):
            q = 'test_' + q
        ans = None
        try:
            for suite in tests:
                for test in suite._tests:
                    if test.__class__.__name__ == 'ModuleImportFailure':
                        raise Exception('Failed to import a test module: %s' % test)
                    for s in test:
                        if s._testMethodName == q:
                            ans = s
                            raise StopIteration()
        except StopIteration:
            pass
        if ans is None:
            print ('No test named %s found' % args.name)
            raise SystemExit(1)
        tests = ans
    else:
        tests = unittest.defaultTestLoader.loadTestsFromName(args.name) if args.name else find_tests()
    r = unittest.TextTestRunner
    r.resultclass = TestResult
    r(verbosity=4).run(tests)

if __name__ == '__main__':
    from calibre.utils.config_base import reset_tweaks_to_default
    from calibre.ebooks.metadata.book.base import reset_field_metadata
    reset_tweaks_to_default()
    reset_field_metadata()
    run_tests()
src/calibre/srv/utils.py (new file)
@@ -0,0 +1,69 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
                        print_function)

__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'

from urlparse import parse_qs

class MultiDict(dict):

    def __setitem__(self, key, val):
        vals = dict.get(self, key, [])
        vals.append(val)
        dict.__setitem__(self, key, vals)

    def __getitem__(self, key):
        return dict.__getitem__(self, key)[-1]

    @staticmethod
    def create_from_query_string(qs):
        ans = MultiDict()
        for k, v in parse_qs(qs, keep_blank_values=True).iteritems():
            dict.__setitem__(ans, k.decode('utf-8'), [x.decode('utf-8') for x in v])
        return ans

    def update_from_listdict(self, ld):
        for key, values in ld.iteritems():
            for val in values:
                self[key] = val

    def items(self, duplicates=True):
        for k, v in dict.iteritems(self):
            if duplicates:
                for x in v:
                    yield k, x
            else:
                yield k, v[-1]
    iteritems = items

    def values(self, duplicates=True):
        for v in dict.itervalues(self):
            if duplicates:
                for x in v:
                    yield x
            else:
                yield v[-1]
    itervalues = values

    def set(self, key, val, replace=False):
        if replace:
            dict.__setitem__(self, key, [val])
        else:
            self[key] = val

    def get(self, key, default=None, all=False):
        if all:
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                return []
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def pop(self, key, default=None, all=False):
        ans = dict.pop(self, key, default)
        if ans is default:
            return [] if all else default
        return ans if all else ans[-1]
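MultiDict keeps every value set for a key while still behaving like a dict on lookup: plain indexing returns the most recently added value, get(..., all=True) returns all of them, set(..., replace=True) discards earlier values, and create_from_query_string() builds one from a raw query string. A small sketch (the keys and values are invented for illustration):

    from calibre.srv.utils import MultiDict

    d = MultiDict()
    d['x'] = 'one'
    d['x'] = 'two'
    print(d['x'])                 # two (last value wins on plain lookup)
    print(d.get('x', all=True))   # ['one', 'two']
    d.set('x', 'three', replace=True)
    print(d.get('x', all=True))   # ['three']

    q = MultiDict.create_from_query_string(b'a=1&a=2&b=')
    print(q.get('a', all=True))   # ['1', '2']
    print(q['b'])                 # prints an empty line (keep_blank_values is True)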