Implement basic and digest auth

This commit is contained in:
Kovid Goyal 2015-06-08 14:14:26 +05:30
parent 5b7ef06e1b
commit 532fd2a7e9
13 changed files with 449 additions and 57 deletions

View File

@ -77,8 +77,7 @@ class OptionRecommendation(object):
self.option.choices: self.option.choices:
raise ValueError('OpRec: %s: Recommended value not in choices'% raise ValueError('OpRec: %s: Recommended value not in choices'%
self.option.name) self.option.name)
if not (isinstance(self.recommended_value, (int, float, str, unicode)) if not (isinstance(self.recommended_value, (int, float, str, unicode)) or self.recommended_value is None):
or self.recommended_value is None):
raise ValueError('OpRec: %s:'%self.option.name + raise ValueError('OpRec: %s:'%self.option.name +
repr(self.recommended_value) + repr(self.recommended_value) +
' is not a string or a number') ' is not a string or a number')
@ -123,6 +122,7 @@ def gui_configuration_widget(name, parent, get_option_by_name,
class InputFormatPlugin(Plugin): class InputFormatPlugin(Plugin):
''' '''
InputFormatPlugins are responsible for converting a document into InputFormatPlugins are responsible for converting a document into
HTML+OPF+CSS+etc. HTML+OPF+CSS+etc.
@ -262,7 +262,7 @@ class InputFormatPlugin(Plugin):
''' '''
Called to create the widget used for configuring this plugin in the Called to create the widget used for configuring this plugin in the
calibre GUI. The widget must be an instance of the PluginWidget class. calibre GUI. The widget must be an instance of the PluginWidget class.
See the builting input plugins for examples. See the builtin input plugins for examples.
''' '''
name = self.name.lower().replace(' ', '_') name = self.name.lower().replace(' ', '_')
return gui_configuration_widget(name, parent, get_option_by_name, return gui_configuration_widget(name, parent, get_option_by_name,
@ -270,6 +270,7 @@ class InputFormatPlugin(Plugin):
class OutputFormatPlugin(Plugin): class OutputFormatPlugin(Plugin):
''' '''
OutputFormatPlugins are responsible for converting an OEB document OutputFormatPlugins are responsible for converting an OEB document
(OPF+HTML) into an output ebook. (OPF+HTML) into an output ebook.
@ -360,7 +361,3 @@ class OutputFormatPlugin(Plugin):
name = self.name.lower().replace(' ', '_') name = self.name.lower().replace(' ', '_')
return gui_configuration_widget(name, parent, get_option_by_name, return gui_configuration_widget(name, parent, get_option_by_name,
get_option_help, db, book_id, for_output=True) get_option_help, db, book_id, for_output=True)

View File

@ -5,3 +5,5 @@ Rewrite server integration with nginx/apache section
Remove dependency on cherrypy from download and contribs pages and remove Remove dependency on cherrypy from download and contribs pages and remove
cherrypy private copy (you will have to re-write jsbrowser.test to not use cherrypy private copy (you will have to re-write jsbrowser.test to not use
cherrypy) cherrypy)
Remove the bundled routes package

218
src/calibre/srv/auth.py Normal file
View File

@ -0,0 +1,218 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import binascii, os, random, struct, base64, httplib
from hashlib import md5, sha1
from calibre.srv.errors import HTTPAuthRequired, HTTPSimpleResponse, InvalidCredentials
from calibre.srv.http_request import parse_uri
from calibre.srv.utils import parse_http_dict
from calibre.utils.monotonic import monotonic
MAX_AGE_SECONDS = 3600
def as_bytestring(x):
if not isinstance(x, bytes):
x = x.encode('utf-8')
return x
def md5_hex(s):
return md5(as_bytestring(s)).hexdigest().decode('ascii')
def sha1_hex(s):
return sha1(as_bytestring(s)).hexdigest().decode('ascii')
def base64_decode(s):
return base64.standard_b64decode(as_bytestring(s)).decode('utf-8')
class DigestAuth(object): # {{{
valid_algorithms = {'MD5', 'MD5-SESS'}
valid_qops = {'auth', 'auth-int'}
def __init__(self, header_val):
data = parse_http_dict(header_val)
self.realm = data.get('realm')
self.username = data.get('username')
self.nonce = data.get('nonce')
self.uri = data.get('uri')
self.method = data.get('method')
self.response = data.get('response')
self.algorithm = data.get('algorithm', 'MD5').upper()
self.cnonce = data.get('cnonce')
self.opaque = data.get('opaque')
self.qop = data.get('qop', '').lower()
self.nonce_count = data.get('nc')
if self.algorithm not in self.valid_algorithms:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'Unsupported digest algorithm')
if not (self.username and self.realm and self.nonce and self.uri and self.response):
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'Digest algorithm required fields missing')
if self.qop:
if self.qop not in self.valid_qops:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'Unsupported digest qop')
if not (self.cnonce and self.nonce_count):
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'qop present, but cnonce and nonce_count absent')
else:
if self.cnonce or self.nonce_count:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'qop missing')
@staticmethod
def synthesize_nonce(realm, secret, timestamp=None):
'''Create a nonce for HTTP Digest AUTH.
The nonce is of the form timestamp:hash with has being a hash of the
timestamp, server secret and realm. This allows the timestamp to be
validated and stale nonce's to be rejected.'''
if timestamp is None:
timestamp = binascii.hexlify(struct.pack(b'!f', float(monotonic())))
h = sha1_hex(':'.join((timestamp, realm, secret)))
nonce = ':'.join((timestamp, h))
return nonce
def validate_nonce(self, realm, secret):
timestamp, hashpart = self.nonce.partition(':')[::2]
s_nonce = DigestAuth.synthesize_nonce(realm, secret, timestamp)
return s_nonce == self.nonce
def is_nonce_stale(self, max_age_seconds=MAX_AGE_SECONDS):
try:
timestamp = struct.unpack(b'!f', binascii.unhexlify(as_bytestring(self.nonce.partition(':')[0])))[0]
return timestamp + max_age_seconds < monotonic()
except Exception:
pass
return True
def H(self, val):
return md5_hex(val)
def H_A2(self, data):
"""Returns the H(A2) string. See :rfc:`2617` section 3.2.2.3."""
# RFC 2617 3.2.2.3
# If the "qop" directive's value is "auth" or is unspecified,
# then A2 is:
# A2 = method ":" digest-uri-value
#
# If the "qop" value is "auth-int", then A2 is:
# A2 = method ":" digest-uri-value ":" H(entity-body)
if self.qop == "auth-int":
a2 = "%s:%s:%s" % (data.method, self.uri, self.H(data.peek()))
else:
a2 = '%s:%s' % (data.method, self.uri)
return self.H(a2)
def request_digest(self, pw, data):
ha1 = self.H(':'.join((self.username, self.realm, pw)))
ha2 = self.H_A2(data)
# Request-Digest -- RFC 2617 3.2.2.1
if self.qop:
req = "%s:%s:%s:%s:%s" % (
self.nonce, self.nonce_count, self.cnonce, self.qop, ha2)
else:
req = "%s:%s" % (self.nonce, ha2)
# RFC 2617 3.2.2.2
#
# If the "algorithm" directive's value is "MD5" or is unspecified,
# then A1 is:
# A1 = unq(username-value) ":" unq(realm-value) ":" passwd
#
# If the "algorithm" directive's value is "MD5-sess", then A1 is
# calculated only once - on the first request by the client following
# receipt of a WWW-Authenticate challenge from the server.
# A1 = H( unq(username-value) ":" unq(realm-value) ":" passwd )
# ":" unq(nonce-value) ":" unq(cnonce-value)
if self.algorithm == 'MD5-SESS':
ha1 = self.H('%s:%s:%s' % (ha1, self.nonce, self.cnonce))
return self.H('%s:%s' % (ha1, req))
def validate_request(self, pw, data, log=None):
# We should also be checking for replay attacks by using nonce_count,
# however, various HTTP clients, most prominently Firefox dont
# implement nonce-counts correctly, so we cannot do the check.
# https://bugzil.la/114451
path = parse_uri(self.uri.encode('utf-8'))[1]
if path != data.path:
if log is not None:
log.warn('Authorization URI mismatch: %s != %s from client: %s' % (
data.path, path, data.remote_addr))
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'The uri in the Request Line and the Authorization header do not match')
return self.response is not None and path == data.path and self.request_digest(pw, data) == self.response
# }}}
class AuthController(object):
def __init__(self, user_credentials=None, prefer_basic_auth=False, realm='calibre', max_age_seconds=MAX_AGE_SECONDS, log=None):
self.user_credentials, self.prefer_basic_auth = user_credentials, prefer_basic_auth
self.log = log
self.secret = binascii.hexlify(os.urandom(random.randint(20, 30))).decode('ascii')
self.max_age_seconds = max_age_seconds
self.key_order = random.choice(('{0}:{1}', '{1}:{0}'))
self.realm = realm
if '"' in realm:
raise ValueError('Double-quotes are not allowed in the authentication realm')
for k, v in self.user_credentials.iteritems():
if '"' in k:
raise ValueError('Double-quotes are not allowed in usernames')
try:
k.encode('ascii'), v.encode('ascii')
except ValueError:
raise InvalidCredentials('Only ASCII characters are allowed in usernames and passwords')
def check(self, un, pw):
return pw and self.user_credentials.get(un) == pw
def __call__(self, data, endpoint):
# TODO: Implement Android workaround for /get
self.do_http_auth(data, endpoint)
def do_http_auth(self, data, endpoint):
auth = data.inheaders.get('Authorization')
nonce_is_stale = False
log_msg = None
data.username = None
if auth:
scheme, rest = auth.partition(' ')[::2]
scheme = scheme.lower()
if scheme == 'digest':
da = DigestAuth(rest.strip())
if da.validate_nonce(self.realm, self.secret):
pw = self.user_credentials.get(da.username)
if pw and da.validate_request(pw, data, self.log):
nonce_is_stale = da.is_nonce_stale(self.max_age_seconds)
if not nonce_is_stale:
data.username = da.username
return
else:
log_msg = 'Failed login attempt from: %s' % data.remote_addr
elif self.prefer_basic_auth and scheme == 'basic':
try:
un, pw = base64_decode(rest.strip()).partition(':')[::2]
except ValueError:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'The username or password contained non-UTF8 encoded characters')
if not un or not pw:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'The username or password was empty')
if self.check(un, pw):
data.username = un
return
else:
log_msg = 'Failed login attempt from: %s' % data.remote_addr
else:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'Unsupported authentication method')
if self.prefer_basic_auth:
raise HTTPAuthRequired('Basic realm="%s"' % self.realm, log=log_msg)
s = 'Digest realm="%s", nonce="%s", algorithm="MD5", qop="auth"' % (
self.realm, DigestAuth.synthesize_nonce(self.realm, self.secret))
if nonce_is_stale:
s += ', stale="true"'
raise HTTPAuthRequired(s, log=log_msg)

View File

@ -16,11 +16,13 @@ class RouteError(ValueError):
class HTTPSimpleResponse(Exception): class HTTPSimpleResponse(Exception):
def __init__(self, http_code, http_message='', close_connection=False, location=None): def __init__(self, http_code, http_message='', close_connection=False, location=None, authenticate=None, log=None):
Exception.__init__(self, http_message) Exception.__init__(self, http_message)
self.http_code = http_code self.http_code = http_code
self.close_connection = close_connection self.close_connection = close_connection
self.location = location self.location = location
self.authenticate = authenticate
self.log = log
class HTTPRedirect(HTTPSimpleResponse): class HTTPRedirect(HTTPSimpleResponse):
@ -31,3 +33,11 @@ class HTTPNotFound(HTTPSimpleResponse):
def __init__(self, http_message='', close_connection=False): def __init__(self, http_message='', close_connection=False):
HTTPSimpleResponse.__init__(self, httplib.NOT_FOUND, http_message, close_connection) HTTPSimpleResponse.__init__(self, httplib.NOT_FOUND, http_message, close_connection)
class HTTPAuthRequired(HTTPSimpleResponse):
def __init__(self, payload, log=None):
HTTPSimpleResponse.__init__(self, httplib.UNAUTHORIZED, authenticate=payload, log=log)
class InvalidCredentials(ValueError):
pass

View File

@ -6,8 +6,6 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3' __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import binascii, os, random
from calibre.srv.routes import Router from calibre.srv.routes import Router
class LibraryBroker(object): class LibraryBroker(object):
@ -23,19 +21,9 @@ class Context(object):
def __init__(self, libraries, opts): def __init__(self, libraries, opts):
self.opts = opts self.opts = opts
self.library_broker = LibraryBroker(libraries) self.library_broker = LibraryBroker(libraries)
self.secret = bytes(binascii.hexlify(os.urandom(random.randint(20, 30))))
self.key_order = random.choice(('{0}:{1}', '{1}:{0}'))
def init_session(self, endpoint, data): def init_session(self, endpoint, data):
cval = data.inheaders.get('Cookie') or '' pass
if isinstance(cval, bytes):
cval = cval.decode('utf-8', 'replace')
data.cookies = c = {}
for x in cval.split(';'):
x = x.strip()
if x:
k, v = x.partition('=')[::2]
c[k] = v
def finalize_session(self, endpoint, data, output): def finalize_session(self, endpoint, data, output):
pass pass

View File

@ -12,6 +12,7 @@ from urllib import unquote
from calibre import as_unicode, force_unicode from calibre import as_unicode, force_unicode
from calibre.ptempfile import SpooledTemporaryFile from calibre.ptempfile import SpooledTemporaryFile
from calibre.srv.errors import HTTPSimpleResponse
from calibre.srv.loop import Connection, READ, WRITE from calibre.srv.loop import Connection, READ, WRITE
from calibre.srv.utils import MultiDict, HTTP1, HTTP11, Accumulator from calibre.srv.utils import MultiDict, HTTP1, HTTP11, Accumulator
@ -19,7 +20,8 @@ protocol_map = {(1, 0):HTTP1, (1, 1):HTTP11}
quoted_slash = re.compile(br'%2[fF]') quoted_slash = re.compile(br'%2[fF]')
HTTP_METHODS = {'HEAD', 'GET', 'PUT', 'POST', 'TRACE', 'DELETE', 'OPTIONS'} HTTP_METHODS = {'HEAD', 'GET', 'PUT', 'POST', 'TRACE', 'DELETE', 'OPTIONS'}
def parse_request_uri(uri): # {{{ # Parse URI {{{
def parse_request_uri(uri):
"""Parse a Request-URI into (scheme, authority, path). """Parse a Request-URI into (scheme, authority, path).
Note that Request-URI's must be one of:: Note that Request-URI's must be one of::
@ -49,7 +51,7 @@ def parse_request_uri(uri): # {{{
# http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query
# ]] # ]]
scheme, remainder = uri[:i].lower(), uri[i + 3:] scheme, remainder = uri[:i].lower(), uri[i + 3:]
authority, path = remainder.split(b'/', 1) authority, path = remainder.partition(b'/')[::2]
path = b'/' + path path = b'/' + path
return scheme, authority, path return scheme, authority, path
@ -59,6 +61,34 @@ def parse_request_uri(uri): # {{{
else: else:
# An authority. # An authority.
return None, uri, None return None, uri, None
def parse_uri(uri, parse_query=True):
scheme, authority, path = parse_request_uri(uri)
if b'#' in path:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.")
if scheme:
try:
scheme = scheme.decode('ascii')
except ValueError:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'Un-decodeable scheme')
path, qs = path.partition(b'?')[::2]
if parse_query:
try:
query = MultiDict.create_from_query_string(qs)
except Exception:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, 'Unparseable query string')
else:
query = None
try:
path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path))
except ValueError as e:
raise HTTPSimpleResponse(httplib.BAD_REQUEST, as_unicode(e))
path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/'))))
return scheme, path, query
# }}} # }}}
# HTTP Header parsing {{{ # HTTP Header parsing {{{
@ -69,13 +99,21 @@ comma_separated_headers = {
'Connection', 'Content-Encoding', 'Content-Language', 'Expect', 'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
'WWW-Authenticate'
} }
decoded_headers = { decoded_headers = {
'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', 'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', 'WWW-Authenticate', 'Authorization',
} | comma_separated_headers } | comma_separated_headers
uppercase_headers = {'WWW', 'TE'}
def normalize_header_name(name):
parts = [x.capitalize() for x in name.split('-')]
q = parts[0].upper()
if q in uppercase_headers:
parts[0] = q
return '-'.join(parts)
class HTTPHeaderParser(object): class HTTPHeaderParser(object):
''' '''
@ -102,7 +140,7 @@ class HTTPHeaderParser(object):
def safe_decode(hname, value): def safe_decode(hname, value):
try: try:
return value.decode('ascii') return value.decode('utf-8')
except UnicodeDecodeError: except UnicodeDecodeError:
if hname in decoded_headers: if hname in decoded_headers:
raise raise
@ -115,7 +153,7 @@ class HTTPHeaderParser(object):
del self.lines[:] del self.lines[:]
k, v = line.partition(b':')[::2] k, v = line.partition(b':')[::2]
key = k.strip().decode('ascii').title() key = normalize_header_name(k.strip().decode('ascii'))
val = safe_decode(key, v.strip()) val = safe_decode(key, v.strip())
if not key or not val: if not key or not val:
raise ValueError('Malformed header line: %s' % reprlib.repr(line)) raise ValueError('Malformed header line: %s' % reprlib.repr(line))
@ -224,29 +262,10 @@ class HTTPRequest(Connection):
except KeyError: except KeyError:
return self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED) return self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED)
self.response_protocol = protocol_map[min((1, 1), rp)] self.response_protocol = protocol_map[min((1, 1), rp)]
scheme, authority, path = parse_request_uri(uri)
if b'#' in path:
return self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.")
if scheme:
try: try:
self.scheme = scheme.decode('ascii') self.scheme, self.path, self.query = parse_uri(uri)
except ValueError: except HTTPSimpleResponse as e:
return self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme') return self.simple_response(e.http_code, e.message, close_after_response=False)
qs = b''
if b'?' in path:
path, qs = path.split(b'?', 1)
try:
self.query = MultiDict.create_from_query_string(qs)
except Exception:
return self.simple_response(httplib.BAD_REQUEST, 'Unparseable query string')
try:
path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path))
except ValueError as e:
return self.simple_response(httplib.BAD_REQUEST, as_unicode(e))
self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/'))))
self.header_line_too_long_error_code = httplib.REQUEST_ENTITY_TOO_LARGE self.header_line_too_long_error_code = httplib.REQUEST_ENTITY_TOO_LARGE
self.request_line = line.rstrip() self.request_line = line.rstrip()
self.set_state(READ, self.parse_header_line, HTTPHeaderParser(), Accumulator()) self.set_state(READ, self.parse_header_line, HTTPHeaderParser(), Accumulator())

View File

@ -180,6 +180,9 @@ def get_range_parts(ranges, content_type, content_length): # {{{
class RequestData(object): # {{{ class RequestData(object): # {{{
cookies = {}
username = None
def __init__(self, method, path, query, inheaders, request_body_file, outheaders, response_protocol, def __init__(self, method, path, query, inheaders, request_body_file, outheaders, response_protocol,
static_cache, opts, remote_addr, remote_port, translator_cache): static_cache, opts, remote_addr, remote_port, translator_cache):
@ -201,6 +204,13 @@ class RequestData(object): # {{{
def read(self, size=-1): def read(self, size=-1):
return self.request_body_file.read(size) return self.request_body_file.read(size)
def peek(self, size=-1):
pos = self.request_body_file.tell()
try:
return self.read(size)
finally:
self.request_body_file.seek(pos)
def get_translator(self, bcp_47_code): def get_translator(self, bcp_47_code):
return get_translator_for_lang(self.translator_cache, bcp_47_code) return get_translator_for_lang(self.translator_cache, bcp_47_code)
@ -368,6 +378,10 @@ class HTTPConnection(HTTPRequest):
eh = {} eh = {}
if e.location: if e.location:
eh['Location'] = e.location eh['Location'] = e.location
if e.authenticate:
eh['WWW-Authenticate'] = e.authenticate
if e.log:
self.log.warn(e.log)
return self.simple_response(e.http_code, msg=e.message or '', close_after_response=e.close_connection, extra_headers=eh) return self.simple_response(e.http_code, msg=e.message or '', close_after_response=e.close_connection, extra_headers=eh)
raise etype, e, tb raise etype, e, tb

View File

@ -117,8 +117,8 @@ class ReadBuffer(object): # {{{
class Connection(object): # {{{ class Connection(object): # {{{
def __init__(self, socket, opts, ssl_context, tdir, addr, pool): def __init__(self, socket, opts, ssl_context, tdir, addr, pool, log):
self.opts, self.pool = opts, pool self.opts, self.pool, self.log = opts, pool, log
try: try:
self.remote_addr = addr[0] self.remote_addr = addr[0]
self.remote_port = addr[1] self.remote_port = addr[1]
@ -532,7 +532,7 @@ class ServerLoop(object):
if sock is not None: if sock is not None:
s = sock.fileno() s = sock.fileno()
if s > -1: if s > -1:
self.connection_map[s] = conn = self.handler(sock, self.opts, self.ssl_context, self.tdir, addr, self.pool) self.connection_map[s] = conn = self.handler(sock, self.opts, self.ssl_context, self.tdir, addr, self.pool, self.log)
if self.ssl_context is not None: if self.ssl_context is not None:
yield s, conn, RDWR yield s, conn, RDWR
elif s == control: elif s == control:

View File

@ -17,11 +17,13 @@ default_methods = frozenset(('HEAD', 'GET'))
def route_key(route): def route_key(route):
return route.partition('{')[0].rstrip('/') return route.partition('{')[0].rstrip('/')
def endpoint(route, methods=default_methods, types=None): def endpoint(route, methods=default_methods, types=None, auth_required=False, android_workaround=False):
def annotate(f): def annotate(f):
f.route = route.rstrip('/') or '/' f.route = route.rstrip('/') or '/'
f.types = types or {} f.types = types or {}
f.methods = methods f.methods = methods
f.auth_required = auth_required
f.android_workaround = android_workaround
f.is_endpoint = True f.is_endpoint = True
return f return f
return annotate return annotate
@ -128,12 +130,16 @@ class Route(object):
class Router(object): class Router(object):
def __init__(self, ctx=None, url_prefix=None): def __init__(self, endpoints=None, ctx=None, url_prefix=None, auth_controller=None):
self.routes = {} self.routes = {}
self.url_prefix = url_prefix or '' self.url_prefix = url_prefix or ''
self.ctx = ctx self.ctx = ctx
self.auth_controller = auth_controller
self.init_session = getattr(ctx, 'init_session', lambda ep, data:None) self.init_session = getattr(ctx, 'init_session', lambda ep, data:None)
self.finalize_session = getattr(ctx, 'finalize_session', lambda ep, data, output:None) self.finalize_session = getattr(ctx, 'finalize_session', lambda ep, data, output:None)
if endpoints is not None:
self.load_routes(endpoints)
self.finalize()
def add(self, endpoint): def add(self, endpoint):
key = route_key(endpoint.route) key = route_key(endpoint.route)
@ -141,6 +147,11 @@ class Router(object):
raise RouteError('A route with the key: %s already exists as: %s' % (key, self.routes[key])) raise RouteError('A route with the key: %s already exists as: %s' % (key, self.routes[key]))
self.routes[key] = Route(endpoint) self.routes[key] = Route(endpoint)
def load_routes(self, items):
for item in items:
if getattr(item, 'is_endpoint', False) is True:
self.add(item)
def __iter__(self): def __iter__(self):
return self.routes.itervalues() return self.routes.itervalues()
@ -168,10 +179,29 @@ class Router(object):
return route.endpoint, args return route.endpoint, args
raise HTTPNotFound() raise HTTPNotFound()
def read_cookies(self, data):
data.cookies = c = {}
for cval in data.inheaders.get('Cookie', all=True):
if isinstance(cval, bytes):
cval = cval.decode('utf-8', 'replace')
for x in cval.split(';'):
x = x.strip()
if x:
k, v = x.partition('=')[::2]
if k:
c[k] = v
def dispatch(self, data): def dispatch(self, data):
endpoint_, args = self.find_route(data.path) endpoint_, args = self.find_route(data.path)
if data.method not in endpoint_.methods: if data.method not in endpoint_.methods:
raise HTTPSimpleResponse(httplib.METHOD_NOT_ALLOWED) raise HTTPSimpleResponse(httplib.METHOD_NOT_ALLOWED)
self.read_cookies(data)
if endpoint_.auth_required and self.auth_controller is not None:
self.auth_controller(data, endpoint_)
self.init_session(endpoint_, data) self.init_session(endpoint_, data)
ans = endpoint_(self.ctx, data, *args) ans = endpoint_(self.ctx, data, *args)
self.finalize_session(endpoint_, data, ans) self.finalize_session(endpoint_, data, ans)

View File

@ -10,6 +10,7 @@ import sys, os, signal
from calibre import as_unicode from calibre import as_unicode
from calibre.constants import plugins, iswindows from calibre.constants import plugins, iswindows
from calibre.srv.errors import InvalidCredentials
from calibre.srv.loop import ServerLoop from calibre.srv.loop import ServerLoop
from calibre.srv.bonjour import BonJour from calibre.srv.bonjour import BonJour
from calibre.srv.opts import opts_to_parser from calibre.srv.opts import opts_to_parser
@ -120,7 +121,10 @@ def main(args=sys.argv):
return auto_reload(default_log) return auto_reload(default_log)
except NoAutoReload as e: except NoAutoReload as e:
raise SystemExit(e.message) raise SystemExit(e.message)
try:
server = Server(libraries, opts) server = Server(libraries, opts)
except InvalidCredentials as e:
raise SystemExit(e.message)
if opts.daemonize: if opts.daemonize:
if not opts.log and not iswindows: if not opts.log and not iswindows:
raise SystemExit('In order to daemonize you must specify a log file, you can use /dev/stdout to log to screen even as a daemon') raise SystemExit('In order to daemonize you must specify a log file, you can use /dev/stdout to log to screen even as a daemon')

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import httplib, base64, urllib2
from calibre.srv.tests.base import BaseTest, TestServer
from calibre.srv.routes import endpoint, Router
REALM = 'calibre-test'
@endpoint('/open')
def noauth(ctx, data):
return 'open'
@endpoint('/closed', auth_required=True)
def auth(ctx, data):
return 'closed'
@endpoint('/android', auth_required=True, android_workaround=True)
def android(ctx, data):
return '/android'
def router(prefer_basic_auth=False):
from calibre.srv.auth import AuthController
return Router(globals().itervalues(), auth_controller=AuthController(
{'testuser':'testpw', '!@#$%^&*()-=_+':'!@#$%^&*()-=_+'},
prefer_basic_auth=prefer_basic_auth, realm=REALM, max_age_seconds=1))
def urlopen(server, path='/closed', un='testuser', pw='testpw', method='digest'):
auth_handler = urllib2.HTTPBasicAuthHandler() if method == 'basic' else urllib2.HTTPDigestAuthHandler()
url = 'http://localhost:%d%s' % (server.address[1], path)
auth_handler.add_password(realm=REALM, uri=url, user=un, passwd=pw)
return urllib2.build_opener(auth_handler).open(url)
class TestAuth(BaseTest):
def test_basic_auth(self):
'Test HTTP Basic auth'
r = router(prefer_basic_auth=True)
with TestServer(r.dispatch) as server:
r.auth_controller.log = server.log
conn = server.connect()
conn.request('GET', '/open')
r = conn.getresponse()
self.ae(r.status, httplib.OK)
self.ae(r.read(), b'open')
conn.request('GET', '/closed')
r = conn.getresponse()
self.ae(r.status, httplib.UNAUTHORIZED)
self.ae(r.getheader('WWW-Authenticate'), b'Basic realm="%s"' % bytes(REALM))
self.assertFalse(r.read())
conn.request('GET', '/closed', headers={'Authorization': b'Basic ' + base64.standard_b64encode(b'testuser:testpw')})
r = conn.getresponse()
self.ae(r.read(), b'closed')
self.ae(r.status, httplib.OK)
self.ae(b'closed', urlopen(server, method='basic').read())
self.ae(b'closed', urlopen(server, un='!@#$%^&*()-=_+', pw='!@#$%^&*()-=_+', method='basic').read())
def request(un='testuser', pw='testpw'):
conn.request('GET', '/closed', headers={'Authorization': b'Basic ' + base64.standard_b64encode(bytes('%s:%s' % (un, pw)))})
r = conn.getresponse()
return r.status, r.read()
warnings = []
server.loop.log.warn = lambda *args, **kwargs: warnings.append(' '.join(args))
self.ae((httplib.OK, b'closed'), request())
self.ae((httplib.UNAUTHORIZED, b''), request('x', 'y'))
self.ae(1, len(warnings))
self.ae((httplib.UNAUTHORIZED, b''), request('testuser', 'y'))
self.ae((httplib.UNAUTHORIZED, b''), request('asf', 'testpw'))
def test_digest_auth(self):
'Test HTTP Digest auth'
from calibre.srv.http_request import normalize_header_name
from calibre.srv.utils import parse_http_dict
r = router()
with TestServer(r.dispatch) as server:
r.auth_controller.log = server.log
def test(conn, path, headers={}, status=httplib.OK, body=b'', request_body=b''):
conn.request('GET', path, request_body, headers)
r = conn.getresponse()
self.ae(r.status, status)
self.ae(r.read(), body)
return {normalize_header_name(k):v for k, v in r.getheaders()}
conn = server.connect()
test(conn, '/open', body=b'open')
auth = parse_http_dict(test(conn, '/closed', status=httplib.UNAUTHORIZED)['WWW-Authenticate'])
self.ae(auth[b'Digest realm'], bytes(REALM)), self.ae(auth[b'algorithm'], b'MD5'), self.ae(auth[b'qop'], b'auth')
self.assertNotIn('stale', auth)
self.ae(urlopen(server).read(), b'closed')

View File

@ -35,7 +35,7 @@ class TestHTTP(BaseTest):
'\r\n', a='one', b='two 2 3', c='three') '\r\n', a='one', b='two 2 3', c='three')
test('Non-ascii headers parsing', test('Non-ascii headers parsing',
b'a:mūs\r', '\r\n', a='mūs'.encode('utf-8')) b'a:mūs\r', '\r\n', a='mūs')
test('Comma-separated parsing', test('Comma-separated parsing',
'Accept-Encoding: one', 'Accept-Encoding: one',
@ -47,7 +47,7 @@ class TestHTTP(BaseTest):
lines.append(b'\r\n') lines.append(b'\r\n')
self.assertRaises(ValueError, HTTPHeaderParser().push, *lines) self.assertRaises(ValueError, HTTPHeaderParser().push, *lines)
parse(b'Connection:mūs\r\n') parse('Connection:mūs\r\n'.encode('utf-16'))
parse(b'Connection\r\n') parse(b'Connection\r\n')
parse(b'Connection:a\r\n', b'\r\n') parse(b'Connection:a\r\n', b'\r\n')
parse(b' Connection:a\n') parse(b' Connection:a\n')

View File

@ -214,6 +214,20 @@ def parse_http_list(header_val):
if part: if part:
yield part.strip() yield part.strip()
def parse_http_dict(header_val):
'Parse an HTTP comma separated header with items of the form a=1, b="xxx" into a dictionary'
if not header_val:
return {}
ans = {}
sep, dquote = b'="' if isinstance(header_val, bytes) else '="'
for item in parse_http_list(header_val):
k, v = item.partition(sep)[::2]
if k:
if v.startswith(dquote) and v.endswith(dquote):
v = v[1:-1]
ans[k] = v
return ans
def sort_q_values(header_val): def sort_q_values(header_val):
'Get sorted items from an HTTP header of type: a;q=0.5, b;q=0.7...' 'Get sorted items from an HTTP header of type: a;q=0.5, b;q=0.7...'
if not header_val: if not header_val: