mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Finish up HTTP header parsing and add some tests
This commit is contained in:
parent
18a04ed23d
commit
48f548236e
@ -6,13 +6,15 @@ from __future__ import (unicode_literals, division, absolute_import,
|
|||||||
__license__ = 'GPL v3'
|
__license__ = 'GPL v3'
|
||||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||||
|
|
||||||
import httplib, socket, re
|
import httplib, socket, re, os
|
||||||
|
from io import BytesIO
|
||||||
|
import repr as reprlib
|
||||||
from urllib import unquote
|
from urllib import unquote
|
||||||
from urlparse import parse_qs
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from calibre import as_unicode
|
from calibre import as_unicode
|
||||||
from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest
|
from calibre.srv.errors import MaxSizeExceeded, NonHTTPConnRequest
|
||||||
|
from calibre.srv.utils import MultiDict
|
||||||
|
|
||||||
HTTP1 = 'HTTP/1.0'
|
HTTP1 = 'HTTP/1.0'
|
||||||
HTTP11 = 'HTTP/1.1'
|
HTTP11 = 'HTTP/1.1'
|
||||||
@ -62,19 +64,19 @@ def parse_request_uri(uri): # {{{
|
|||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
comma_separated_headers = {
|
comma_separated_headers = {
|
||||||
b'Accept', b'Accept-Charset', b'Accept-Encoding',
|
'Accept', 'Accept-Charset', 'Accept-Encoding',
|
||||||
b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control',
|
'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
|
||||||
b'Connection', b'Content-Encoding', b'Content-Language', b'Expect',
|
'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
|
||||||
b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE',
|
'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
|
||||||
b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning',
|
'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
|
||||||
b'WWW-Authenticate'
|
'WWW-Authenticate'
|
||||||
}
|
}
|
||||||
|
|
||||||
decoded_headers = {
|
decoded_headers = {
|
||||||
'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
|
'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
|
||||||
}
|
} | comma_separated_headers
|
||||||
|
|
||||||
def read_headers(readline, max_line_size, hdict=None): # {{{
|
def read_headers(readline): # {{{
|
||||||
"""
|
"""
|
||||||
Read headers from the given stream into the given header dict.
|
Read headers from the given stream into the given header dict.
|
||||||
|
|
||||||
@ -87,8 +89,27 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
|
|||||||
This function raises ValueError when the read bytes violate the HTTP spec.
|
This function raises ValueError when the read bytes violate the HTTP spec.
|
||||||
You should probably return "400 Bad Request" if this happens.
|
You should probably return "400 Bad Request" if this happens.
|
||||||
"""
|
"""
|
||||||
if hdict is None:
|
hdict = MultiDict()
|
||||||
hdict = {}
|
|
||||||
|
def safe_decode(hname, value):
|
||||||
|
try:
|
||||||
|
return value.decode('ascii')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
if hname in decoded_headers:
|
||||||
|
raise
|
||||||
|
return value
|
||||||
|
|
||||||
|
current_key = current_value = None
|
||||||
|
|
||||||
|
def commit():
|
||||||
|
if current_key:
|
||||||
|
key = current_key.decode('ascii')
|
||||||
|
val = safe_decode(key, current_value)
|
||||||
|
if key in comma_separated_headers:
|
||||||
|
existing = hdict.pop(key)
|
||||||
|
if existing is not None:
|
||||||
|
val = existing + ', ' + val
|
||||||
|
hdict[key] = val
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
line = readline()
|
line = readline()
|
||||||
@ -98,32 +119,22 @@ def read_headers(readline, max_line_size, hdict=None): # {{{
|
|||||||
|
|
||||||
if line == b'\r\n':
|
if line == b'\r\n':
|
||||||
# Normal end of headers
|
# Normal end of headers
|
||||||
|
commit()
|
||||||
break
|
break
|
||||||
if not line.endswith(b'\r\n'):
|
if not line.endswith(b'\r\n'):
|
||||||
raise ValueError("HTTP requires CRLF terminators")
|
raise ValueError("HTTP requires CRLF terminators")
|
||||||
|
|
||||||
if line[0] in (b' ', b'\t'):
|
if line[0] in b' \t':
|
||||||
# It's a continuation line.
|
# It's a continuation line.
|
||||||
v = line.strip()
|
if current_key is None or current_value is None:
|
||||||
|
raise ValueError('Orphaned continuation line')
|
||||||
|
current_value += b' ' + line.strip()
|
||||||
else:
|
else:
|
||||||
try:
|
commit()
|
||||||
|
current_key = current_value = None
|
||||||
k, v = line.split(b':', 1)
|
k, v = line.split(b':', 1)
|
||||||
except ValueError:
|
current_key = k.strip().title()
|
||||||
raise ValueError("Illegal header line.")
|
current_value = v.strip()
|
||||||
k = k.strip().title()
|
|
||||||
v = v.strip()
|
|
||||||
hname = k.decode('ascii')
|
|
||||||
|
|
||||||
if k in comma_separated_headers:
|
|
||||||
existing = hdict.get(hname)
|
|
||||||
if existing:
|
|
||||||
v = b", ".join((existing, v))
|
|
||||||
try:
|
|
||||||
v = v.decode('ascii')
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
if hname in decoded_headers:
|
|
||||||
raise
|
|
||||||
hdict[hname] = v
|
|
||||||
|
|
||||||
return hdict
|
return hdict
|
||||||
# }}}
|
# }}}
|
||||||
@ -133,53 +144,130 @@ def http_communicate(conn):
|
|||||||
request_seen = False
|
request_seen = False
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
# (re)set req to None so that if something goes wrong in
|
# (re)set pair to None so that if something goes wrong in
|
||||||
# the HTTPPair constructor, the error doesn't
|
# the HTTPPair constructor, the error doesn't
|
||||||
# get written to the previous request.
|
# get written to the previous request.
|
||||||
req = None
|
pair = None
|
||||||
req = conn.server_loop.http_handler(conn)
|
pair = conn.server_loop.http_handler(conn)
|
||||||
|
|
||||||
# This order of operations should guarantee correct pipelining.
|
# This order of operations should guarantee correct pipelining.
|
||||||
req.parse_request()
|
pair.parse_request()
|
||||||
if not req.ready:
|
if not pair.ready:
|
||||||
# Something went wrong in the parsing (and the server has
|
# Something went wrong in the parsing (and the server has
|
||||||
# probably already made a simple_response). Return and
|
# probably already made a simple_response). Return and
|
||||||
# let the conn close.
|
# let the conn close.
|
||||||
return
|
return
|
||||||
|
|
||||||
request_seen = True
|
request_seen = True
|
||||||
req.respond()
|
pair.respond()
|
||||||
if req.close_connection:
|
if pair.close_connection:
|
||||||
return
|
return
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
# Don't error if we're between requests; only error
|
# Don't error if we're between requests; only error
|
||||||
# if 1) no request has been started at all, or 2) we're
|
# if 1) no request has been started at all, or 2) we're
|
||||||
# in the middle of a request. This allows persistent
|
# in the middle of a request. This allows persistent
|
||||||
# connections for HTTP/1.1
|
# connections for HTTP/1.1
|
||||||
if (not request_seen) or (req and req.started_request):
|
if (not request_seen) or (pair and pair.started_request):
|
||||||
# Don't bother writing the 408 if the response
|
# Don't bother writing the 408 if the response
|
||||||
# has already started being written.
|
# has already started being written.
|
||||||
if req and not req.sent_headers:
|
if pair and not pair.sent_headers:
|
||||||
req.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
|
pair.simple_response(httplib.REQUEST_TIMEOUT, "Request Timeout")
|
||||||
except NonHTTPConnRequest:
|
except NonHTTPConnRequest:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
conn.server_loop.log.exception()
|
conn.server_loop.log.exception('Error serving request:', pair.path if pair else None)
|
||||||
if req and not req.sent_headers:
|
if pair and not pair.sent_headers:
|
||||||
req.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
|
pair.simple_response(httplib.INTERNAL_SERVER_ERROR, "Internal Server Error")
|
||||||
|
|
||||||
|
class FixedSizeReader(object):
|
||||||
|
|
||||||
|
def __init__(self, socket_file, content_length):
|
||||||
|
self.socket_file, self.remaining = socket_file, content_length
|
||||||
|
|
||||||
|
def __call__(self, size=-1):
|
||||||
|
if size < 0:
|
||||||
|
size = self.remaining
|
||||||
|
size = min(self.remaining, size)
|
||||||
|
if size < 1:
|
||||||
|
return b''
|
||||||
|
data = self.socket_file.read(size)
|
||||||
|
self.remaining -= len(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkedReader(object):
|
||||||
|
|
||||||
|
def __init__(self, socket_file, maxsize):
|
||||||
|
self.socket_file, self.maxsize = socket_file, maxsize
|
||||||
|
self.rbuf = BytesIO()
|
||||||
|
self.bytes_read = 0
|
||||||
|
self.finished = False
|
||||||
|
|
||||||
|
def check_size(self):
|
||||||
|
if self.bytes_read > self.maxsize:
|
||||||
|
raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize)
|
||||||
|
|
||||||
|
def read_chunk(self):
|
||||||
|
if self.finished:
|
||||||
|
return
|
||||||
|
line = self.socket_file.readline()
|
||||||
|
self.bytes_read += len(line)
|
||||||
|
self.check_size()
|
||||||
|
chunk_size = line.strip().split(b';', 1)[0]
|
||||||
|
try:
|
||||||
|
chunk_size = int(line, 16) + 2
|
||||||
|
except Exception:
|
||||||
|
raise ValueError('%s is not a valid chunk size' % reprlib.repr(chunk_size))
|
||||||
|
if chunk_size + self.bytes_read > self.maxsize:
|
||||||
|
raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize)
|
||||||
|
chunk = self.socket_file.read(chunk_size)
|
||||||
|
if len(chunk) < chunk_size:
|
||||||
|
raise ValueError('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size))
|
||||||
|
if not chunk.endswith(b'\r\n'):
|
||||||
|
raise ValueError('Bad chunked encoding: %r != CRLF' % chunk[:-2])
|
||||||
|
self.rbuf.seek(0, os.SEEK_END)
|
||||||
|
self.bytes_read += chunk_size
|
||||||
|
if chunk_size == 2:
|
||||||
|
self.finished = True
|
||||||
|
else:
|
||||||
|
self.rbuf.write(chunk[:-2])
|
||||||
|
|
||||||
|
def __call__(self, size=-1):
|
||||||
|
if size < 0:
|
||||||
|
# Read all data
|
||||||
|
while not self.finished:
|
||||||
|
self.read_chunk()
|
||||||
|
self.rbuf.seek(0)
|
||||||
|
rv = self.rbuf.read()
|
||||||
|
if rv:
|
||||||
|
self.rbuf.truncate(0)
|
||||||
|
return rv
|
||||||
|
if size == 0:
|
||||||
|
return b''
|
||||||
|
while self.rbuf.tell() < size and not self.finished:
|
||||||
|
self.read_chunk()
|
||||||
|
data = self.rbuf.getvalue()
|
||||||
|
self.rbuf.truncate(0)
|
||||||
|
if size < len(data):
|
||||||
|
self.rbuf.write(data[size:])
|
||||||
|
return data[:size]
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class HTTPPair(object):
|
class HTTPPair(object):
|
||||||
|
|
||||||
''' Represents a HTTP request/response pair '''
|
''' Represents a HTTP request/response pair '''
|
||||||
|
|
||||||
def __init__(self, conn):
|
def __init__(self, conn, handle_request):
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
self.server_loop = conn.server_loop
|
self.server_loop = conn.server_loop
|
||||||
self.max_header_line_size = self.server_loop.max_header_line_size
|
self.max_header_line_size = self.server_loop.max_header_line_size
|
||||||
self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
|
self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
|
||||||
self.inheaders = {}
|
self.inheaders = MultiDict()
|
||||||
self.outheaders = []
|
self.outheaders = MultiDict()
|
||||||
|
self.handle_request = handle_request
|
||||||
|
self.path = ()
|
||||||
|
self.qs = MultiDict()
|
||||||
|
|
||||||
"""When True, the request has been parsed and is ready to begin generating
|
"""When True, the request has been parsed and is ready to begin generating
|
||||||
the response. When False, signals the calling Connection that the response
|
the response. When False, signals the calling Connection that the response
|
||||||
@ -198,6 +286,9 @@ class HTTPPair(object):
|
|||||||
self.status = b''
|
self.status = b''
|
||||||
self.sent_headers = False
|
self.sent_headers = False
|
||||||
|
|
||||||
|
self.request_content_length = 0
|
||||||
|
self.chunked_read = False
|
||||||
|
|
||||||
def parse_request(self):
|
def parse_request(self):
|
||||||
"""Parse the next HTTP request start-line and message-headers."""
|
"""Parse the next HTTP request start-line and message-headers."""
|
||||||
try:
|
try:
|
||||||
@ -273,7 +364,7 @@ class HTTPPair(object):
|
|||||||
if b'?' in path:
|
if b'?' in path:
|
||||||
path, qs = path.split(b'?', 1)
|
path, qs = path.split(b'?', 1)
|
||||||
try:
|
try:
|
||||||
self.qs = {k.decode('utf-8'):tuple(x.decode('utf-8') for x in v) for k, v in parse_qs(qs, keep_blank_values=True).iteritems()}
|
self.qs = MultiDict.create_from_query_string(qs)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line",
|
self.simple_response(httplib.BAD_REQUEST, "Bad Request", "Malformed Request-Line",
|
||||||
'Unparseable query string')
|
'Unparseable query string')
|
||||||
@ -293,13 +384,13 @@ class HTTPPair(object):
|
|||||||
def read_request_headers(self):
|
def read_request_headers(self):
|
||||||
# then all the http headers
|
# then all the http headers
|
||||||
try:
|
try:
|
||||||
read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size), self.inheaders)
|
self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size))
|
||||||
content_length = int(self.inheaders.get('Content-Length', 0))
|
self.request_content_length = int(self.inheaders.get('Content-Length', 0))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e))
|
self.simple_response(httplib.BAD_REQUEST, "Bad Request", as_unicode(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if content_length > self.server_loop.max_request_body_size:
|
if self.request_content_length > self.server_loop.max_request_body_size:
|
||||||
self.simple_response(
|
self.simple_response(
|
||||||
httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large",
|
httplib.REQUEST_ENTITY_TOO_LARGE, "Request Entity Too Large",
|
||||||
"The entity sent with the request exceeds the maximum "
|
"The entity sent with the request exceeds the maximum "
|
||||||
|
1
src/calibre/srv/tests/__init__.py
Normal file
1
src/calibre/srv/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
22
src/calibre/srv/tests/base.py
Normal file
22
src/calibre/srv/tests/base.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#!/usr/bin/env python2
|
||||||
|
# vim:fileencoding=UTF-8:ts=4:sw=4:sta:et:sts=4:ai
|
||||||
|
from __future__ import (unicode_literals, division, absolute_import,
|
||||||
|
print_function)
|
||||||
|
|
||||||
|
__license__ = 'GPL v3'
|
||||||
|
__copyright__ = '2011, Kovid Goyal <kovid@kovidgoyal.net>'
|
||||||
|
__docformat__ = 'restructuredtext en'
|
||||||
|
|
||||||
|
import unittest, shutil
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
rmtree = partial(shutil.rmtree, ignore_errors=True)
|
||||||
|
|
||||||
|
class BaseTest(unittest.TestCase):
|
||||||
|
|
||||||
|
longMessage = True
|
||||||
|
maxDiff = None
|
||||||
|
|
||||||
|
ae = unittest.TestCase.assertEqual
|
||||||
|
|
||||||
|
|
47
src/calibre/srv/tests/http.py
Normal file
47
src/calibre/srv/tests/http.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env python2
|
||||||
|
# vim:fileencoding=utf-8
|
||||||
|
from __future__ import (unicode_literals, division, absolute_import,
|
||||||
|
print_function)
|
||||||
|
|
||||||
|
__license__ = 'GPL v3'
|
||||||
|
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from calibre.srv.tests.base import BaseTest
|
||||||
|
|
||||||
|
def headers(raw):
|
||||||
|
return BytesIO(textwrap.dedent(raw).encode('utf-8'))
|
||||||
|
|
||||||
|
class TestHTTP(BaseTest):
|
||||||
|
|
||||||
|
def test_header_parsing(self):
|
||||||
|
'Test parsing of HTTP headers'
|
||||||
|
from calibre.srv.http import read_headers
|
||||||
|
|
||||||
|
def test(name, raw, **kwargs):
|
||||||
|
hdict = read_headers(headers(raw).readline)
|
||||||
|
self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')
|
||||||
|
|
||||||
|
test('Continuation line parsing',
|
||||||
|
'''\
|
||||||
|
a: one\r
|
||||||
|
b: two\r
|
||||||
|
2\r
|
||||||
|
\t3\r
|
||||||
|
c:three\r
|
||||||
|
\r\n''', a='one', b='two 2 3', c='three')
|
||||||
|
|
||||||
|
test('Non-ascii headers parsing',
|
||||||
|
'''\
|
||||||
|
a:mūs\r
|
||||||
|
\r\n''', a='mūs'.encode('utf-8'))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
read_headers(headers('Connection:mūs\r\n').readline)
|
||||||
|
read_headers(headers('Connection\r\n').readline)
|
||||||
|
read_headers(headers('Connection:a\r\n').readline)
|
||||||
|
read_headers(headers('Connection:a\n').readline)
|
||||||
|
read_headers(headers(' Connection:a\n').readline)
|
||||||
|
|
101
src/calibre/srv/tests/main.py
Normal file
101
src/calibre/srv/tests/main.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
#!/usr/bin/env python2
|
||||||
|
# vim:fileencoding=utf-8
|
||||||
|
from __future__ import (unicode_literals, division, absolute_import,
|
||||||
|
print_function)
|
||||||
|
|
||||||
|
__license__ = 'GPL v3'
|
||||||
|
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||||
|
|
||||||
|
import unittest, os, argparse, time, functools
|
||||||
|
|
||||||
|
try:
|
||||||
|
import init_calibre
|
||||||
|
del init_calibre
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def no_endl(f):
|
||||||
|
@functools.wraps(f)
|
||||||
|
def func(*args, **kwargs):
|
||||||
|
self = f.__self__
|
||||||
|
orig = self.stream.writeln
|
||||||
|
self.stream.writeln = self.stream.write
|
||||||
|
try:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
self.stream.writeln = orig
|
||||||
|
return func
|
||||||
|
|
||||||
|
class TestResult(unittest.TextTestResult):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(TestResult, self).__init__(*args, **kwargs)
|
||||||
|
self.start_time = {}
|
||||||
|
for x in ('Success', 'Error', 'Failure', 'Skip', 'ExpectedFailure', 'UnexpectedSuccess'):
|
||||||
|
x = 'add' + x
|
||||||
|
setattr(self, x, no_endl(getattr(self, x)))
|
||||||
|
self.times = {}
|
||||||
|
|
||||||
|
def startTest(self, test):
|
||||||
|
self.start_time[test] = time.time()
|
||||||
|
return super(TestResult, self).startTest(test)
|
||||||
|
|
||||||
|
def stopTest(self, test):
|
||||||
|
orig = self.stream.writeln
|
||||||
|
self.stream.writeln = self.stream.write
|
||||||
|
super(TestResult, self).stopTest(test)
|
||||||
|
elapsed = time.time()
|
||||||
|
elapsed -= self.start_time.get(test, elapsed)
|
||||||
|
self.times[test] = elapsed
|
||||||
|
self.stream.writeln = orig
|
||||||
|
self.stream.writeln(' [%.1g s]' % elapsed)
|
||||||
|
|
||||||
|
def stopTestRun(self):
|
||||||
|
super(TestResult, self).stopTestRun()
|
||||||
|
if self.wasSuccessful():
|
||||||
|
tests = sorted(self.times, key=self.times.get, reverse=True)
|
||||||
|
slowest = ['%s [%g s]' % (t.id(), self.times[t]) for t in tests[:3]]
|
||||||
|
if len(slowest) > 1:
|
||||||
|
self.stream.writeln('\nSlowest tests: %s' % ' '.join(slowest))
|
||||||
|
|
||||||
|
def find_tests():
|
||||||
|
return unittest.defaultTestLoader.discover(os.path.dirname(os.path.abspath(__file__)), pattern='*.py')
|
||||||
|
|
||||||
|
def run_tests(find_tests=find_tests):
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('name', nargs='?', default=None,
|
||||||
|
help='The name of the test to run, for e.g. writing.WritingTest.many_many_basic or .many_many_basic for a shortcut')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.name and args.name.startswith('.'):
|
||||||
|
tests = find_tests()
|
||||||
|
q = args.name[1:]
|
||||||
|
if not q.startswith('test_'):
|
||||||
|
q = 'test_' + q
|
||||||
|
ans = None
|
||||||
|
try:
|
||||||
|
for suite in tests:
|
||||||
|
for test in suite._tests:
|
||||||
|
if test.__class__.__name__ == 'ModuleImportFailure':
|
||||||
|
raise Exception('Failed to import a test module: %s' % test)
|
||||||
|
for s in test:
|
||||||
|
if s._testMethodName == q:
|
||||||
|
ans = s
|
||||||
|
raise StopIteration()
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
if ans is None:
|
||||||
|
print ('No test named %s found' % args.name)
|
||||||
|
raise SystemExit(1)
|
||||||
|
tests = ans
|
||||||
|
else:
|
||||||
|
tests = unittest.defaultTestLoader.loadTestsFromName(args.name) if args.name else find_tests()
|
||||||
|
r = unittest.TextTestRunner
|
||||||
|
r.resultclass = TestResult
|
||||||
|
r(verbosity=4).run(tests)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from calibre.utils.config_base import reset_tweaks_to_default
|
||||||
|
from calibre.ebooks.metadata.book.base import reset_field_metadata
|
||||||
|
reset_tweaks_to_default()
|
||||||
|
reset_field_metadata()
|
||||||
|
run_tests()
|
69
src/calibre/srv/utils.py
Normal file
69
src/calibre/srv/utils.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env python2
|
||||||
|
# vim:fileencoding=utf-8
|
||||||
|
from __future__ import (unicode_literals, division, absolute_import,
|
||||||
|
print_function)
|
||||||
|
|
||||||
|
__license__ = 'GPL v3'
|
||||||
|
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||||
|
|
||||||
|
from urlparse import parse_qs
|
||||||
|
|
||||||
|
class MultiDict(dict):
|
||||||
|
|
||||||
|
def __setitem__(self, key, val):
|
||||||
|
vals = dict.get(self, key, [])
|
||||||
|
vals.append(val)
|
||||||
|
dict.__setitem__(self, key, vals)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return dict.__getitem__(self, key)[-1]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_from_query_string(qs):
|
||||||
|
ans = MultiDict()
|
||||||
|
for k, v in parse_qs(qs, keep_blank_values=True):
|
||||||
|
dict.__setitem__(ans, k.decode('utf-8'), [x.decode('utf-8') for x in v])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def update_from_listdict(self, ld):
|
||||||
|
for key, values in ld.iteritems():
|
||||||
|
for val in values:
|
||||||
|
self[key] = val
|
||||||
|
|
||||||
|
def items(self, duplicates=True):
|
||||||
|
for k, v in dict.iteritems(self):
|
||||||
|
if duplicates:
|
||||||
|
for x in v:
|
||||||
|
yield k, x
|
||||||
|
else:
|
||||||
|
yield k, v[-1]
|
||||||
|
iteritems = items
|
||||||
|
|
||||||
|
def values(self, duplicates=True):
|
||||||
|
for v in dict.itervalues(self):
|
||||||
|
if duplicates:
|
||||||
|
for x in v:
|
||||||
|
yield x
|
||||||
|
else:
|
||||||
|
yield v[-1]
|
||||||
|
itervalues = values
|
||||||
|
|
||||||
|
def set(self, key, val, replace=False):
|
||||||
|
if replace:
|
||||||
|
dict.__setitem__(self, key, [val])
|
||||||
|
else:
|
||||||
|
self[key] = val
|
||||||
|
|
||||||
|
def get(self, key, default=None, all=False):
|
||||||
|
if all:
|
||||||
|
try:
|
||||||
|
return dict.__getitem__(self, key)
|
||||||
|
except KeyError:
|
||||||
|
return []
|
||||||
|
return self.__getitem__(key)
|
||||||
|
|
||||||
|
def pop(self, key, default=None, all=False):
|
||||||
|
ans = dict.pop(self, key, default)
|
||||||
|
if ans is default:
|
||||||
|
return [] if all else default
|
||||||
|
return ans if all else ans[-1]
|
Loading…
x
Reference in New Issue
Block a user