From d075eff7582adfd14ab1c0c5639fad23d12b2602 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 25 May 2015 18:34:51 +0530 Subject: [PATCH] Finish implementation of async http server --- src/calibre/srv/async.py | 337 ------------ src/calibre/srv/errors.py | 21 - src/calibre/srv/http.py | 602 --------------------- src/calibre/srv/http_request.py | 374 +++++++++++++ src/calibre/srv/http_response.py | 525 ++++++++++++++++++ src/calibre/srv/loop.py | 900 +++++++++---------------------- src/calibre/srv/opts.py | 16 +- src/calibre/srv/respond.py | 356 ------------ src/calibre/srv/sendfile.py | 39 +- src/calibre/srv/tests/base.py | 6 +- src/calibre/srv/tests/http.py | 165 ++++-- src/calibre/srv/utils.py | 37 +- 12 files changed, 1314 insertions(+), 2064 deletions(-) delete mode 100644 src/calibre/srv/async.py delete mode 100644 src/calibre/srv/http.py create mode 100644 src/calibre/srv/http_request.py create mode 100644 src/calibre/srv/http_response.py delete mode 100644 src/calibre/srv/respond.py diff --git a/src/calibre/srv/async.py b/src/calibre/srv/async.py deleted file mode 100644 index d689b364bb..0000000000 --- a/src/calibre/srv/async.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env python2 -# vim:fileencoding=utf-8 -from __future__ import (unicode_literals, division, absolute_import, - print_function) - -__license__ = 'GPL v3' -__copyright__ = '2015, Kovid Goyal ' - -import ssl, socket, select, os -from io import BytesIO - -from calibre import as_unicode -from calibre.srv.opts import Options -from calibre.srv.utils import ( - socket_errors_socket_closed, socket_errors_nonblocking, HandleInterrupt) -from calibre.utils.socket_inheritance import set_socket_inherit -from calibre.utils.logging import ThreadSafeLog -from calibre.utils.monotonic import monotonic - -READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR' - -class Connection(object): - - def __init__(self, socket, opts, ssl_context): - self.opts = opts - self.ssl_context = ssl_context - self.wait_for = READ - 
if self.ssl_context is not None: - self.ready = False - self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False) - self.set_state(RDWR, self.do_ssl_handshake) - else: - self.ready = True - self.socket = socket - self.connection_ready() - self.last_activity = monotonic() - - def set_state(self, wait_for, func): - self.wait_for = wait_for - self.handle_event = func - - def do_ssl_handshake(self, event): - try: - self._sslobj.do_handshake() - except ssl.SSLWantReadError: - self.set_state(READ, self.do_ssl_handshake) - except ssl.SSLWantWriteError: - self.set_state(WRITE, self.do_ssl_handshake) - self.ready = True - self.connection_ready() - - def send(self, data): - try: - ret = self.socket.send(data) - self.last_activity = monotonic() - return ret - except socket.error as e: - if e.errno in socket_errors_nonblocking: - return 0 - elif e.errno in socket_errors_socket_closed: - self.ready = False - return 0 - raise - - def recv(self, buffer_size): - try: - data = self.socket.recv(buffer_size) - self.last_activity = monotonic() - if not data: - # a closed connection is indicated by signaling - # a read condition, and having recv() return 0. - self.ready = False - return b'' - return data - except socket.error as e: - if e.errno in socket_errors_socket_closed: - self.ready = False - return b'' - - def close(self): - self.ready = False - try: - self.socket.shutdown(socket.SHUT_WR) - self.socket.close() - except socket.error: - pass - - def connection_ready(self): - raise NotImplementedError() - -class ServerLoop(object): - - def __init__( - self, - handler, - bind_address=('localhost', 8080), - opts=None, - # A calibre logging object. 
If None, a default log that logs to - # stdout is used - log=None - ): - self.ready = False - self.handler = handler - self.opts = opts or Options() - self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG) - - ba = tuple(bind_address) - if not ba[0]: - # AI_PASSIVE does not work with host of '' or None - ba = ('0.0.0.0', ba[1]) - self.bind_address = ba - self.bound_address = None - - self.ssl_context = None - if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None: - self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - self.ssl_context.load_cert_chain(certfile=self.opts.ssl_certfile, keyfile=self.opts.ssl_keyfile) - - self.pre_activated_socket = None - if self.opts.allow_socket_preallocation: - from calibre.srv.pre_activated import pre_activated_socket - self.pre_activated_socket = pre_activated_socket() - if self.pre_activated_socket is not None: - set_socket_inherit(self.pre_activated_socket, False) - self.bind_address = self.pre_activated_socket.getsockname() - - def __str__(self): - return "%s(%r)" % (self.__class__.__name__, self.bind_address) - __repr__ = __str__ - - def serve_forever(self): - """ Listen for incoming connections. 
""" - - if self.pre_activated_socket is None: - # AF_INET or AF_INET6 socket - # Get the correct address family for our host (allows IPv6 - # addresses) - host, port = self.bind_address - try: - info = socket.getaddrinfo( - host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, 0, socket.AI_PASSIVE) - except socket.gaierror: - if ':' in host: - info = [(socket.AF_INET6, socket.SOCK_STREAM, - 0, "", self.bind_address + (0, 0))] - else: - info = [(socket.AF_INET, socket.SOCK_STREAM, - 0, "", self.bind_address)] - - self.socket = None - msg = "No socket could be created" - for res in info: - af, socktype, proto, canonname, sa = res - try: - self.bind(af, socktype, proto) - except socket.error, serr: - msg = "%s -- (%s: %s)" % (msg, sa, serr) - if self.socket: - self.socket.close() - self.socket = None - continue - break - if not self.socket: - raise socket.error(msg) - else: - self.socket = self.pre_activated_socket - self.pre_activated_socket = None - self.setup_socket() - - self.ready = True - self.connection_map = {} - self.socket.listen(min(socket.SOMAXCONN, 128)) - self.bound_address = ba = self.socket.getsockname() - if isinstance(ba, tuple): - ba = ':'.join(map(type(''), ba)) - self.log('calibre server listening on', ba) - - while True: - try: - self.tick() - except (KeyboardInterrupt, SystemExit): - self.shutdown() - break - except: - self.log.exception('Error in ServerLoop.tick') - - def setup_socket(self): - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if self.opts.no_delay and not isinstance(self.bind_address, basestring): - self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), - # activate dual-stack. 
- if (hasattr(socket, 'AF_INET6') and self.socket.family == socket.AF_INET6 and - self.bind_address[0] in ('::', '::0', '::0.0.0.0')): - try: - self.socket.setsockopt( - socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except (AttributeError, socket.error): - # Apparently, the socket option is not available in - # this machine's TCP stack - pass - self.socket.setblocking(0) - - def bind(self, family, atype, proto=0): - '''Create (or recreate) the actual socket object.''' - self.socket = socket.socket(family, atype, proto) - set_socket_inherit(self.socket, False) - self.setup_socket() - self.socket.bind(self.bind_address) - - def tick(self): - now = monotonic() - for s, conn in tuple(self.connection_map.iteritems()): - if now - conn.last_activity > self.opts.timeout: - self.log.debug('Closing connection because of extended inactivity') - self.close(s, conn) - - read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR] - write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR] - readable, writable = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout)[:2] - if not self.ready: - return - - for s, conn, event in self.get_actions(readable, writable): - try: - conn.handle_event(event) - if not conn.ready: - self.close(s, conn) - except Exception as e: - if conn.ready: - self.log.exception('Unhandled exception, terminating connection') - else: - self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e)) - self.close(s, conn) - - def wakeup(self): - # Touch our own socket to make select() return immediately. 
- sock = getattr(self, "socket", None) - if sock is not None: - try: - host, port = sock.getsockname()[:2] - except socket.error as e: - if e.errno not in socket_errors_socket_closed: - raise - else: - for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - s = None - try: - s = socket.socket(af, socktype, proto) - s.settimeout(1.0) - s.connect((host, port)) - s.close() - except socket.error: - if s is not None: - s.close() - return sock - - def close(self, s, conn): - self.connection_map.pop(s, None) - conn.close() - - def get_actions(self, readable, writable): - for s in readable: - if s is self.socket: - s, addr = self.accept() - if s is not None: - self.connection_map[s] = conn = self.handler(s, self.opts, self.ssl_context) - if self.ssl_context is not None: - yield s, conn, RDWR - else: - yield s, self.connection_map[s], READ - for s in writable: - yield s, self.connection_map[s], WRITE - - def accept(self): - try: - return self.socket.accept() - except socket.error: - return None, None - - def stop(self): - self.ready = False - self.wakeup() - - def shutdown(self): - try: - if getattr(self, 'socket', None): - self.socket.close() - self.socket = None - except socket.error: - pass - for s, conn in tuple(self.connection_map.iteritems()): - self.close(s, conn) - -class EchoLine(Connection): # {{{ - - bye_after_echo = False - - def connection_ready(self): - self.rbuf = BytesIO() - self.set_state(READ, self.read_line) - - def read_line(self, event): - data = self.recv(1) - if data: - self.rbuf.write(data) - if b'\n' == data: - if self.rbuf.tell() < 3: - # Empty line - self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue()) - self.bye_after_echo = True - self.set_state(WRITE, self.echo) - self.rbuf.seek(0) - - def echo(self, event): - pos = self.rbuf.tell() - self.rbuf.seek(0, os.SEEK_END) - left = self.rbuf.tell() - pos - self.rbuf.seek(pos) - sent = self.send(self.rbuf.read(512)) - if sent == left: 
- self.rbuf = BytesIO() - self.set_state(READ, self.read_line) - if self.bye_after_echo: - self.ready = False - else: - self.rbuf.seek(pos + sent) -# }}} - -if __name__ == '__main__': - s = ServerLoop(EchoLine) - with HandleInterrupt(s.wakeup): - s.serve_forever() diff --git a/src/calibre/srv/errors.py b/src/calibre/srv/errors.py index bfd951d7fe..b0073c3d8d 100644 --- a/src/calibre/srv/errors.py +++ b/src/calibre/srv/errors.py @@ -7,26 +7,5 @@ __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' -class MaxSizeExceeded(Exception): - - def __init__(self, prefix, size, limit): - Exception.__init__(self, prefix + (' %d > maximum %d' % (size, limit))) - self.size = size - self.limit = limit - class HTTP404(Exception): pass - -class IfNoneMatch(Exception): - def __init__(self, etag=None): - Exception.__init__(self, '') - self.etag = etag - -class BadChunkedInput(ValueError): - pass - -class RangeNotSatisfiable(ValueError): - - def __init__(self, content_length): - ValueError.__init__(self) - self.content_length = content_length diff --git a/src/calibre/srv/http.py b/src/calibre/srv/http.py deleted file mode 100644 index 704d32f232..0000000000 --- a/src/calibre/srv/http.py +++ /dev/null @@ -1,602 +0,0 @@ -#!/usr/bin/env python2 -# vim:fileencoding=utf-8 -from __future__ import (unicode_literals, division, absolute_import, - print_function) - -__license__ = 'GPL v3' -__copyright__ = '2015, Kovid Goyal ' - -import httplib, socket, re, os -from io import BytesIO -import repr as reprlib -from urllib import unquote -from functools import partial -from operator import itemgetter - -from calibre import as_unicode -from calibre.constants import __version__ -from calibre.srv.errors import ( - MaxSizeExceeded, HTTP404, IfNoneMatch, BadChunkedInput, RangeNotSatisfiable) -from calibre.srv.respond import finalize_output, generate_static_output -from calibre.srv.utils import MultiDict, http_date, socket_errors_to_ignore - -HTTP1 = 'HTTP/1.0' -HTTP11 = 'HTTP/1.1' -protocol_map = 
{(1, 0):HTTP1, (1, 1):HTTP11} -quoted_slash = re.compile(br'%2[fF]') - -def parse_request_uri(uri): # {{{ - """Parse a Request-URI into (scheme, authority, path). - - Note that Request-URI's must be one of:: - - Request-URI = "*" | absoluteURI | abs_path | authority - - Therefore, a Request-URI which starts with a double forward-slash - cannot be a "net_path":: - - net_path = "//" authority [ abs_path ] - - Instead, it must be interpreted as an "abs_path" with an empty first - path segment:: - - abs_path = "/" path_segments - path_segments = segment *( "/" segment ) - segment = *pchar *( ";" param ) - param = *pchar - """ - if uri == b'*': - return None, None, uri - - i = uri.find(b'://') - if i > 0 and b'?' not in uri[:i]: - # An absoluteURI. - # If there's a scheme (and it must be http or https), then: - # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query - # ]] - scheme, remainder = uri[:i].lower(), uri[i + 3:] - authority, path = remainder.split(b'/', 1) - path = b'/' + path - return scheme, authority, path - - if uri.startswith(b'/'): - # An abs_path. - return None, None, uri - else: - # An authority. - return None, uri, None -# }}} - -comma_separated_headers = { - 'Accept', 'Accept-Charset', 'Accept-Encoding', - 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', - 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', - 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', - 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', - 'WWW-Authenticate' -} - -decoded_headers = { - 'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', -} | comma_separated_headers - -def read_headers(readline): # {{{ - """ - Read headers from the given stream into the given header dict. - - If hdict is None, a new header dict is created. Returns the populated - header dict. - - Headers which are repeated are folded together using a comma if their - specification so dictates. 
- - This function raises ValueError when the read bytes violate the HTTP spec. - You should probably return "400 Bad Request" if this happens. - """ - hdict = MultiDict() - - def safe_decode(hname, value): - try: - return value.decode('ascii') - except UnicodeDecodeError: - if hname in decoded_headers: - raise - return value - - current_key = current_value = None - - def commit(): - if current_key: - key = current_key.decode('ascii') - val = safe_decode(key, current_value) - if key in comma_separated_headers: - existing = hdict.pop(key) - if existing is not None: - val = existing + ', ' + val - hdict[key] = val - - while True: - line = readline() - if not line: - # No more data--illegal end of headers - raise ValueError("Illegal end of headers.") - - if line == b'\r\n': - # Normal end of headers - commit() - break - if not line.endswith(b'\r\n'): - raise ValueError("HTTP requires CRLF terminators") - - if line[0] in b' \t': - # It's a continuation line. - if current_key is None or current_value is None: - raise ValueError('Orphaned continuation line') - current_value += b' ' + line.strip() - else: - commit() - current_key = current_value = None - k, v = line.split(b':', 1) - current_key = k.strip().title() - current_value = v.strip() - - return hdict -# }}} - -def http_communicate(handle_request, conn): - ' Represents interaction with a http client over a single, persistent connection ' - request_seen = False - def repr_for_pair(pair): - return pair.repr_for_log() if getattr(pair, 'started_request', False) else 'None' - - def simple_response(pair, code, msg=''): - if pair and not pair.sent_headers: - try: - pair.simple_response(code, msg=msg) - except socket.error as e: - if e.errno not in socket_errors_to_ignore: - raise - - try: - while True: - # (re)set pair to None so that if something goes wrong in - # the HTTPPair constructor, the error doesn't - # get written to the previous request. 
- pair = None - pair = HTTPPair(handle_request, conn) - - # This order of operations should guarantee correct pipelining. - pair.parse_request() - if not pair.ready: - # Something went wrong in the parsing (and the server has - # probably already made a simple_response). Return and - # let the conn close. - return - - request_seen = True - pair.respond() - if pair.close_connection: - return - except socket.timeout: - # Don't error if we're between requests; only error - # if 1) no request has been started at all, or 2) we're - # in the middle of a request. This allows persistent - # connections for HTTP/1.1 - if (not request_seen) or (pair and pair.started_request): - # Don't bother writing the 408 if the response - # has already started being written. - simple_response(pair, httplib.REQUEST_TIMEOUT) - except socket.error: - # This socket is broken. Log the error and close connection - conn.server_loop.log.exception( - 'Communication failed (socket error) while processing request:', repr_for_pair(pair)) - except MaxSizeExceeded as e: - conn.server_loop.log.warn('Too large request body (%d > %d) for request:' % (e.size, e.limit), repr_for_pair(pair)) - # Can happen if the request uses chunked transfer encoding - simple_response(pair, httplib.REQUEST_ENTITY_TOO_LARGE, - "The entity sent with the request exceeds the maximum " - "allowed bytes (%d)." 
% pair.max_request_body_size) - except BadChunkedInput as e: - conn.server_loop.log.warn('Bad chunked encoding (%s) for request:' % as_unicode(e.message), repr_for_pair(pair)) - simple_response(pair, httplib.BAD_REQUEST, - 'Invalid chunked encoding for request body: %s' % as_unicode(e.message)) - except Exception: - conn.server_loop.log.exception('Error serving request:', pair.repr_for_log() if getattr(pair, 'started_request', False) else 'None') - simple_response(pair, httplib.INTERNAL_SERVER_ERROR) - -class FixedSizeReader(object): # {{{ - - def __init__(self, socket_file, content_length): - self.socket_file, self.remaining = socket_file, content_length - - def read(self, size=-1): - if size < 0: - size = self.remaining - size = min(self.remaining, size) - if size < 1: - return b'' - data = self.socket_file.read(size) - self.remaining -= len(data) - return data -# }}} - -class ChunkedReader(object): # {{{ - - def __init__(self, socket_file, maxsize): - self.socket_file, self.maxsize = socket_file, maxsize - self.rbuf = BytesIO() - self.bytes_read = 0 - self.finished = False - - def check_size(self): - if self.bytes_read > self.maxsize: - raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize) - - def read_chunk(self): - if self.finished: - return - line = self.socket_file.readline() - self.bytes_read += len(line) - self.check_size() - chunk_size = line.strip().split(b';', 1)[0] - try: - chunk_size = int(line, 16) + 2 - except Exception: - raise BadChunkedInput('%s is not a valid chunk size' % reprlib.repr(chunk_size)) - if chunk_size + self.bytes_read > self.maxsize: - raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize) - try: - chunk = self.socket_file.read(chunk_size) - except socket.timeout: - raise BadChunkedInput('Timed out waiting for chunk of size %d to complete' % chunk_size) - if len(chunk) < chunk_size: - raise BadChunkedInput('Bad chunked encoding, chunk truncated: %d < %s' % 
(len(chunk), chunk_size)) - if not chunk.endswith(b'\r\n'): - raise BadChunkedInput('Bad chunked encoding: %r != CRLF' % chunk[-2:]) - self.rbuf.seek(0, os.SEEK_END) - self.bytes_read += chunk_size - if chunk_size == 2: - self.finished = True - else: - self.rbuf.write(chunk[:-2]) - - def read(self, size=-1): - if size < 0: - # Read all data - while not self.finished: - self.read_chunk() - self.rbuf.seek(0) - rv = self.rbuf.read() - if rv: - self.rbuf.truncate(0) - return rv - if size == 0: - return b'' - while self.rbuf.tell() < size and not self.finished: - self.read_chunk() - data = self.rbuf.getvalue() - self.rbuf.truncate(0) - if size < len(data): - self.rbuf.write(data[size:]) - return data[:size] - return data -# }}} - -class HTTPPair(object): - - ''' Represents a HTTP request/response pair ''' - - def __init__(self, handle_request, conn): - self.conn = conn - self.server_loop = conn.server_loop - self.max_header_line_size = int(self.server_loop.opts.max_header_line_size * 1024) - self.max_request_body_size = int(self.server_loop.opts.max_request_body_size * 1024 * 1024) - self.scheme = 'http' if self.server_loop.ssl_context is None else 'https' - self.inheaders = MultiDict() - self.outheaders = MultiDict() - self.handle_request = handle_request - self.request_line = None - self.path = () - self.qs = MultiDict() - self.method = None - - """When True, the request has been parsed and is ready to begin generating - the response. When False, signals the calling Connection that the response - should not be generated and the connection should close, immediately after - parsing the request.""" - self.ready = False - - """Signals the calling Connection that the request should close. This does - not imply an error! 
The client and/or server may each request that the - connection be closed, after the response.""" - self.close_connection = False - - self.started_request = False - self.response_protocol = HTTP1 - - self.status_code = None - self.sent_headers = False - - self.request_content_length = 0 - self.chunked_read = False - - def parse_request(self): - """Parse the next HTTP request start-line and message-headers.""" - try: - if not self.read_request_line(): - return - except MaxSizeExceeded as e: - self.server_loop.log.warn('Too large request URI (%d > %d), dropping connection' % (e.size, e.limit)) - self.simple_response( - httplib.REQUEST_URI_TOO_LONG, - "The Request-URI sent with the request exceeds the maximum allowed bytes.") - return - - try: - if not self.read_request_headers(): - return - except MaxSizeExceeded as e: - self.server_loop.log.warn('Too large header (%d > %d) for request, dropping connection' % (e.size, e.limit)) - self.simple_response( - httplib.REQUEST_ENTITY_TOO_LARGE, - "The headers sent with the request exceed the maximum allowed bytes.") - return - - self.ready = True - - def read_request_line(self): - self.request_line = request_line = self.conn.socket_file.readline(maxsize=self.max_header_line_size) - - # Set started_request to True so http_communicate() knows to send 408 - # from here on out. - self.started_request = True - if not request_line: - return False - - if request_line == b'\r\n': - # RFC 2616 sec 4.1: "...if the server is reading the protocol - # stream at the beginning of a message and receives a CRLF - # first, it should ignore the CRLF." - # But only ignore one leading line! else we enable a DoS. 
- request_line = self.conn.socket_file.readline(maxsize=self.max_header_line_size) - if not request_line: - return False - - if not request_line.endswith(b'\r\n'): - self.simple_response( - httplib.BAD_REQUEST, "HTTP requires CRLF terminators") - return False - - try: - method, uri, req_protocol = request_line.strip().split(b' ', 2) - rp = int(req_protocol[5]), int(req_protocol[7]) - self.method = method.decode('ascii') - except (ValueError, IndexError): - self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line") - return False - - try: - self.request_protocol = protocol_map[rp] - except KeyError: - self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED) - return False - self.response_protocol = protocol_map[min((1, 1), rp)] - - scheme, authority, path = parse_request_uri(uri) - if b'#' in path: - self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.") - return False - - if scheme: - try: - self.scheme = scheme.decode('ascii') - except ValueError: - self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme') - return False - - qs = b'' - if b'?' 
in path: - path, qs = path.split(b'?', 1) - try: - self.qs = MultiDict.create_from_query_string(qs) - except Exception: - self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line", - 'Unparseable query string') - return False - - try: - path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path)) - except ValueError as e: - self.simple_response(httplib.BAD_REQUEST, as_unicode(e)) - return False - self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/')))) - - return True - - def read_request_headers(self): - # then all the http headers - try: - self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size)) - self.request_content_length = int(self.inheaders.get('Content-Length', 0)) - except ValueError as e: - self.simple_response(httplib.BAD_REQUEST, as_unicode(e)) - return False - - if self.request_content_length > self.max_request_body_size: - self.simple_response( - httplib.REQUEST_ENTITY_TOO_LARGE, - "The entity sent with the request exceeds the maximum " - "allowed bytes (%d)." % self.max_request_body_size) - return False - - # Persistent connection support - if self.response_protocol is HTTP11: - # Both server and client are HTTP/1.1 - if self.inheaders.get("Connection", "") == "close": - self.close_connection = True - else: - # Either the server or client (or both) are HTTP/1.0 - if self.inheaders.get("Connection", "") != "Keep-Alive": - self.close_connection = True - - # Transfer-Encoding support - te = () - if self.response_protocol is HTTP11: - rte = self.inheaders.get("Transfer-Encoding") - if rte: - te = [x.strip().lower() for x in rte.split(",") if x.strip()] - self.chunked_read = False - if te: - for enc in te: - if enc == "chunked": - self.chunked_read = True - else: - # Note that, even if we see "chunked", we must reject - # if there is an extension we don't recognize. 
- self.simple_response(httplib.NOT_IMPLEMENTED, "Unknown transfer encoding: %r" % enc) - self.close_connection = True - return False - - if self.inheaders.get("Expect", '').lower() == "100-continue": - # Don't use simple_response here, because it emits headers - # we don't want. - msg = HTTP11 + " 100 Continue\r\n\r\n" - self.flushed_write(msg.encode('ascii')) - return True - - def simple_response(self, status_code, msg="", read_remaining_input=False): - abort = status_code in (httplib.REQUEST_ENTITY_TOO_LARGE, httplib.REQUEST_URI_TOO_LONG) - if abort: - self.close_connection = True - if self.response_protocol is HTTP1: - # HTTP/1.0 has no 413/414 codes - status_code = httplib.BAD_REQUEST - - msg = msg.encode('utf-8') - buf = [ - '%s %d %s' % (self.response_protocol, status_code, httplib.responses[status_code]), - "Content-Length: %s" % len(msg), - "Content-Type: text/plain; charset=UTF-8", - "Date: " + http_date(), - ] - if abort and self.response_protocol is HTTP11: - buf.append("Connection: close") - buf.append('') - buf = [(x + '\r\n').encode('ascii') for x in buf] - if self.method != 'HEAD': - buf.append(msg) - if read_remaining_input: - self.input_reader.read() - self.flushed_write(b''.join(buf)) - - def send_not_modified(self, etag=None): - buf = [ - '%s %d %s' % (self.response_protocol, httplib.NOT_MODIFIED, httplib.responses[httplib.NOT_MODIFIED]), - "Content-Length: 0", - "Date: " + http_date(), - ] - if etag is not None: - buf.append('ETag: ' + etag) - self.send_buf(buf) - - def send_buf(self, buf, include_cache_headers=True): - if include_cache_headers: - for header in ('Expires', 'Cache-Control', 'Vary'): - val = self.outheaders.get(header) - if val: - buf.append(header + ': ' + val) - buf.append('') - buf = [(x + '\r\n').encode('ascii') for x in buf] - self.flushed_write(b''.join(buf)) - - def send_range_not_satisfiable(self, content_length): - buf = [ - '%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, 
httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]), - "Date: " + http_date(), - "Content-Range: bytes */%d" % content_length, - ] - self.send_buf(buf) - - def flushed_write(self, data): - self.conn.socket_file.write(data) - self.conn.socket_file.flush() - - def repr_for_log(self): - ans = ['HTTPPair: %r' % self.request_line] - if self.path: - ans.append('Path: %r' % (self.path,)) - if self.qs: - ans.append('Query: %r' % self.qs) - if self.inheaders: - ans.extend(('In Headers:', self.inheaders.pretty('\t'))) - if self.outheaders: - ans.extend(('Out Headers:', self.outheaders.pretty('\t'))) - return '\n'.join(ans) - - def generate_static_output(self, name, generator): - return generate_static_output(self.server_loop.gso_cache, self.server_loop.gso_lock, name, generator) - - def respond(self): - if self.chunked_read: - self.input_reader = ChunkedReader(self.conn.socket_file, self.max_request_body_size) - else: - self.input_reader = FixedSizeReader(self.conn.socket_file, self.request_content_length) - - try: - output = self.handle_request(self) - except HTTP404 as e: - self.simple_response(httplib.NOT_FOUND, e.message, read_remaining_input=True) - return - # Read and discard any remaining body from the HTTP request - self.input_reader.read() - if self.status_code is None: - self.status_code = httplib.CREATED if self.method == 'POST' else httplib.OK - - try: - self.status_code, output = finalize_output( - output, self.inheaders, self.outheaders, self.status_code, - self.response_protocol is HTTP1, self.method, self.server_loop.opts) - except IfNoneMatch as e: - if self.method in ('GET', 'HEAD'): - self.send_not_modified(e.etag) - else: - self.simple_response(httplib.PRECONDITION_FAILED) - return - except RangeNotSatisfiable as e: - self.send_range_not_satisfiable(e.content_length) - return - - with self.conn.corked: - self.send_headers() - if self.method != 'HEAD': - output.commit(self.conn.socket_file) - self.conn.socket_file.flush() - - def 
send_headers(self): - self.sent_headers = True - self.outheaders.set('Date', http_date(), replace_all=True) - self.outheaders.set('Server', 'calibre %s' % __version__, replace_all=True) - keep_alive = not self.close_connection and self.server_loop.opts.timeout > 0 - if keep_alive: - self.outheaders.set('Keep-Alive', 'timeout=%d' % self.server_loop.opts.timeout) - if 'Connection' not in self.outheaders: - if self.response_protocol is HTTP11: - if self.close_connection: - self.outheaders.set('Connection', 'close') - else: - if not self.close_connection: - self.outheaders.set('Connection', 'Keep-Alive') - - ct = self.outheaders.get('Content-Type', '') - if ct.startswith('text/') and 'charset=' not in ct: - self.outheaders.set('Content-Type', ct + '; charset=UTF-8') - - buf = [HTTP11 + (' %d ' % self.status_code) + httplib.responses[self.status_code]] - for header, value in sorted(self.outheaders.iteritems(), key=itemgetter(0)): - buf.append('%s: %s' % (header, value)) - buf.append('') - self.conn.socket_file.write(b''.join((x + '\r\n').encode('ascii') for x in buf)) - - -def create_http_handler(handle_request): - return partial(http_communicate, handle_request) diff --git a/src/calibre/srv/http_request.py b/src/calibre/srv/http_request.py new file mode 100644 index 0000000000..e6684c2079 --- /dev/null +++ b/src/calibre/srv/http_request.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +__license__ = 'GPL v3' +__copyright__ = '2015, Kovid Goyal ' + +import re, httplib, repr as reprlib +from io import BytesIO, DEFAULT_BUFFER_SIZE +from urllib import unquote + +from calibre import as_unicode, force_unicode +from calibre.ptempfile import SpooledTemporaryFile +from calibre.srv.loop import Connection, READ, WRITE +from calibre.srv.utils import MultiDict, HTTP1, HTTP11 + +protocol_map = {(1, 0):HTTP1, (1, 1):HTTP11} +quoted_slash = re.compile(br'%2[fF]') 
+HTTP_METHODS = {'HEAD', 'GET', 'PUT', 'POST', 'TRACE', 'DELETE', 'OPTIONS'} + +def parse_request_uri(uri): # {{{ + """Parse a Request-URI into (scheme, authority, path). + + Note that Request-URI's must be one of:: + + Request-URI = "*" | absoluteURI | abs_path | authority + + Therefore, a Request-URI which starts with a double forward-slash + cannot be a "net_path":: + + net_path = "//" authority [ abs_path ] + + Instead, it must be interpreted as an "abs_path" with an empty first + path segment:: + + abs_path = "/" path_segments + path_segments = segment *( "/" segment ) + segment = *pchar *( ";" param ) + param = *pchar + """ + if uri == b'*': + return None, None, uri + + i = uri.find(b'://') + if i > 0 and b'?' not in uri[:i]: + # An absoluteURI. + # If there's a scheme (and it must be http or https), then: + # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query + # ]] + scheme, remainder = uri[:i].lower(), uri[i + 3:] + authority, path = remainder.split(b'/', 1) + path = b'/' + path + return scheme, authority, path + + if uri.startswith(b'/'): + # An abs_path. + return None, None, uri + else: + # An authority. + return None, uri, None +# }}} + +# HTTP Header parsing {{{ + +comma_separated_headers = { + 'Accept', 'Accept-Charset', 'Accept-Encoding', + 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control', + 'Connection', 'Content-Encoding', 'Content-Language', 'Expect', + 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE', + 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning', + 'WWW-Authenticate' +} + +decoded_headers = { + 'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect', +} | comma_separated_headers + +class HTTPHeaderParser(object): + + ''' + Parse HTTP headers. Use this class by repeatedly calling the created object + with a single line at a time and checking the finished attribute. Can raise ValueError + for malformed headers, in which case you should probably return BAD_REQUEST. 
+ + Headers which are repeated are folded together using a comma if their + specification so dictates. + ''' + __slots__ = ('hdict', 'lines', 'finished') + + def __init__(self): + self.hdict = MultiDict() + self.lines = [] + self.finished = False + + def push(self, *lines): + for line in lines: + self(line) + + def __call__(self, line): + 'Process a single line' + + def safe_decode(hname, value): + try: + return value.decode('ascii') + except UnicodeDecodeError: + if hname in decoded_headers: + raise + return value + + def commit(): + if not self.lines: + return + line = b' '.join(self.lines) + del self.lines[:] + + k, v = line.partition(b':')[::2] + key = k.strip().decode('ascii').title() + val = safe_decode(key, v.strip()) + if not key or not val: + raise ValueError('Malformed header line: %s' % reprlib.repr(line)) + if key in comma_separated_headers: + existing = self.hdict.pop(key) + if existing is not None: + val = existing + ', ' + val + self.hdict[key] = val + + if line == b'\r\n': + # Normal end of headers + commit() + self.finished = True + return + + if line[0] in b' \t': + # It's a continuation line. 
+ if not self.lines: + raise ValueError('Orphaned continuation line') + self.lines.append(line.lstrip()) + else: + commit() + self.lines.append(line) + +def read_headers(readline): + p = HTTPHeaderParser() + while not p.finished: + p(readline()) + return p.hdict +# }}} + +class HTTPRequest(Connection): + + request_handler = None + static_cache = None + + def __init__(self, *args, **kwargs): + Connection.__init__(self, *args, **kwargs) + self.corked = False + self.max_header_line_size = int(1024 * self.opts.max_header_line_size) + self.max_request_body_size = int(1024 * 1024 * self.opts.max_request_body_size) + + def read(self, buf, endpos): + size = endpos - buf.tell() + if size > 0: + data = self.recv(min(size, DEFAULT_BUFFER_SIZE)) + if data: + buf.write(data) + return len(data) >= size + else: + return False + else: + return True + + def readline(self, buf): + if buf.tell() >= self.max_header_line_size - 1: + self.simple_response(self.header_line_too_long_error_code) + return + data = self.recv(1) + if data: + buf.write(data) + if b'\n' == data: + line = buf.getvalue() + buf.seek(0), buf.truncate() + if line.endswith(b'\r\n'): + return line + else: + self.simple_response(httplib.BAD_REQUEST, 'HTTP requires CRLF line terminators') + + def connection_ready(self): + 'Become ready to read an HTTP request' + self.method = self.request_line = None + self.response_protocol = self.request_protocol = HTTP1 + self.path = self.query = None + self.close_after_response = False + self.header_line_too_long_error_code = httplib.REQUEST_URI_TOO_LONG + self.response_started = False + self.set_state(READ, self.parse_request_line, BytesIO(), first=True) + + def parse_request_line(self, buf, event, first=False): # {{{ + line = self.readline(buf) + if line is None: + return + if line == b'\r\n': + # Ignore a single leading empty line, as per RFC 2616 sec 4.1 + if first: + return self.set_state(READ, self.parse_request_line, BytesIO()) + return 
self.simple_response(httplib.BAD_REQUEST, 'Multiple leading empty lines not allowed') + + try: + method, uri, req_protocol = line.strip().split(b' ', 2) + rp = int(req_protocol[5]), int(req_protocol[7]) + self.method = method.decode('ascii').upper() + except Exception: + return self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line") + + if self.method not in HTTP_METHODS: + return self.simple_response(httplib.BAD_REQUEST, "Unknown HTTP method") + + try: + self.request_protocol = protocol_map[rp] + except KeyError: + return self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED) + self.response_protocol = protocol_map[min((1, 1), rp)] + scheme, authority, path = parse_request_uri(uri) + if b'#' in path: + return self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.") + + if scheme: + try: + self.scheme = scheme.decode('ascii') + except ValueError: + return self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme') + + qs = b'' + if b'?' in path: + path, qs = path.split(b'?', 1) + try: + self.query = MultiDict.create_from_query_string(qs) + except Exception: + return self.simple_response(httplib.BAD_REQUEST, 'Unparseable query string') + + try: + path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path)) + except ValueError as e: + return self.simple_response(httplib.BAD_REQUEST, as_unicode(e)) + self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/')))) + self.header_line_too_long_error_code = httplib.REQUEST_ENTITY_TOO_LARGE + self.request_line = line.rstrip() + self.set_state(READ, self.parse_header_line, HTTPHeaderParser(), BytesIO()) + # }}} + + @property + def state_description(self): + return 'Request: %s' % force_unicode(self.request_line, 'utf-8') + + def parse_header_line(self, parser, buf, event): + line = self.readline(buf) + if line is None: + return + try: + parser(line) + except ValueError: + self.simple_response(httplib.BAD_REQUEST, 'Failed to parse header line') + 
return + if parser.finished: + self.finalize_headers(parser.hdict) + + def finalize_headers(self, inheaders): + request_content_length = int(inheaders.get('Content-Length', 0)) + if request_content_length > self.max_request_body_size: + return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE, + "The entity sent with the request exceeds the maximum " + "allowed bytes (%d)." % self.max_request_body_size) + # Persistent connection support + if self.response_protocol is HTTP11: + # Both server and client are HTTP/1.1 + if inheaders.get("Connection", "") == "close": + self.close_after_response = True + else: + # Either the server or client (or both) are HTTP/1.0 + if inheaders.get("Connection", "") != "Keep-Alive": + self.close_after_response = True + + # Transfer-Encoding support + te = () + if self.response_protocol is HTTP11: + rte = inheaders.get("Transfer-Encoding") + if rte: + te = [x.strip().lower() for x in rte.split(",") if x.strip()] + chunked_read = False + if te: + for enc in te: + if enc == "chunked": + chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. 
+ return self.simple_response(httplib.NOT_IMPLEMENTED, "Unknown transfer encoding: %r" % enc) + + if inheaders.get("Expect", '').lower() == "100-continue": + buf = BytesIO((HTTP11 + " 100 Continue\r\n\r\n").encode('ascii')) + return self.set_state(WRITE, self.write_continue, buf, inheaders, request_content_length, chunked_read) + + self.read_request_body(inheaders, request_content_length, chunked_read) + + def write_continue(self, buf, inheaders, request_content_length, chunked_read, event): + if self.write(buf): + self.read_request_body(inheaders, request_content_length, chunked_read) + + def read_request_body(self, inheaders, request_content_length, chunked_read): + buf = SpooledTemporaryFile(prefix='rq-body-', max_size=DEFAULT_BUFFER_SIZE, dir=self.tdir) + if chunked_read: + self.set_state(READ, self.read_chunk_length, inheaders, BytesIO(), buf, [0]) + else: + if request_content_length > 0: + self.set_state(READ, self.sized_read, inheaders, buf, request_content_length) + else: + self.prepare_response(inheaders, BytesIO()) + + def sized_read(self, inheaders, buf, request_content_length, event): + if self.read(buf, request_content_length): + self.prepare_response(inheaders, buf) + + def read_chunk_length(self, inheaders, line_buf, buf, bytes_read, event): + line = self.readline(line_buf) + if line is None: + return + bytes_read[0] += len(line) + try: + chunk_size = int(line.strip(), 16) + except Exception: + return self.simple_response(httplib.BAD_REQUEST, '%s is not a valid chunk size' % reprlib.repr(line.strip())) + if bytes_read[0] + chunk_size + 2 > self.max_request_body_size: + return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE, + 'Chunked request is larger than %d bytes' % self.max_request_body_size) + if chunk_size == 0: + self.set_state(READ, self.read_chunk_separator, inheaders, BytesIO(), buf, bytes_read, last=True) + else: + self.set_state(READ, self.read_chunk, inheaders, buf, chunk_size, buf.tell() + chunk_size, bytes_read) + + def 
read_chunk(self, inheaders, buf, chunk_size, end, bytes_read, event): + if not self.read(buf, end): + return + bytes_read[0] += chunk_size + self.set_state(READ, self.read_chunk_separator, inheaders, BytesIO(), buf, bytes_read) + + def read_chunk_separator(self, inheaders, line_buf, buf, bytes_read, event, last=False): + line = self.readline(line_buf) + if line is None: + return + if line != b'\r\n': + return self.simple_response(httplib.BAD_REQUEST, 'Chunk does not have trailing CRLF') + bytes_read[0] += len(line) + if bytes_read[0] > self.max_request_body_size: + return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE, + 'Chunked request is larger than %d bytes' % self.max_request_body_size) + if last: + self.prepare_response(inheaders, buf) + else: + self.set_state(READ, self.read_chunk_length, inheaders, BytesIO(), buf, bytes_read) + + def handle_timeout(self): + if self.response_started: + return False + self.simple_response(httplib.REQUEST_TIMEOUT) + return True + + def write(self, buf, end=None): + raise NotImplementedError() + + def simple_response(self, status_code, msg='', close_after_response=True): + raise NotImplementedError() + + def prepare_response(self, inheaders, request_body_file): + raise NotImplementedError() diff --git a/src/calibre/srv/http_response.py b/src/calibre/srv/http_response.py new file mode 100644 index 0000000000..f6b975ed95 --- /dev/null +++ b/src/calibre/srv/http_response.py @@ -0,0 +1,525 @@ +#!/usr/bin/env python2 +# vim:fileencoding=utf-8 +from __future__ import (unicode_literals, division, absolute_import, + print_function) + +__license__ = 'GPL v3' +__copyright__ = '2015, Kovid Goyal ' + +import os, httplib, hashlib, uuid, zlib, time, struct, repr as reprlib +from select import PIPE_BUF +from collections import namedtuple +from io import BytesIO, DEFAULT_BUFFER_SIZE +from itertools import chain, repeat, izip_longest +from operator import itemgetter +from functools import wraps + +from calibre import guess_type, 
force_unicode +from calibre.constants import __version__ +from calibre.srv.loop import WRITE +from calibre.srv.errors import HTTP404 +from calibre.srv.http_request import HTTPRequest, read_headers +from calibre.srv.sendfile import file_metadata, sendfile_to_socket_async, CannotSendfile, SendfileInterrupted +from calibre.srv.utils import MultiDict, start_cork, stop_cork, http_date, HTTP1, HTTP11, socket_errors_socket_closed + +Range = namedtuple('Range', 'start stop size') +MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii') + +def header_list_to_file(buf): # {{{ + buf.append('') + return BytesIO(b''.join((x + '\r\n').encode('ascii') for x in buf)) +# }}} + +def parse_multipart_byterange(buf, content_type): # {{{ + sep = (content_type.rsplit('=', 1)[-1]).encode('utf-8') + ans = [] + + def parse_part(): + line = buf.readline() + if not line: + raise ValueError('Premature end of message') + if not line.startswith(b'--' + sep): + raise ValueError('Malformed start of multipart message: %s' % reprlib.repr(line)) + if line.endswith(b'--'): + return None + headers = read_headers(buf.readline) + cr = headers.get('Content-Range') + if not cr: + raise ValueError('Missing Content-Range header in sub-part') + if not cr.startswith('bytes '): + raise ValueError('Malformed Content-Range header in sub-part, no prefix') + try: + start, stop = map(lambda x: int(x.strip()), cr.partition(' ')[-1].partition('/')[0].partition('-')[::2]) + except Exception: + raise ValueError('Malformed Content-Range header in sub-part, failed to parse byte range') + content_length = stop - start + 1 + ret = buf.read(content_length) + if len(ret) != content_length: + raise ValueError('Malformed sub-part, length of body not equal to length specified in Content-Range') + buf.readline() + return (start, ret) + while True: + data = parse_part() + if data is None: + break + ans.append(data) + return ans +# }}} + +def parse_if_none_match(val): # {{{ + return {x.strip() for x in val.split(',')} +# }}} + +def 
acceptable_encoding(val, allowed=frozenset({'gzip'})): # {{{ + def enc(x): + e, r = x.partition(';')[::2] + p, v = r.partition('=')[::2] + q = 1.0 + if p == 'q' and v: + try: + q = float(v) + except Exception: + pass + return e.lower(), q + + emap = dict(enc(x.strip()) for x in val.split(',')) + acceptable = sorted(set(emap) & allowed, key=emap.__getitem__, reverse=True) + if acceptable: + return acceptable[0] +# }}} + +def get_ranges(headervalue, content_length): # {{{ + ''' Return a list of ranges from the Range header. If this function returns + an empty list, it indicates no valid range was found. ''' + if not headervalue: + return None + + result = [] + try: + bytesunit, byteranges = headervalue.split("=", 1) + except Exception: + return None + if bytesunit.strip() != 'bytes': + return None + + for brange in byteranges.split(","): + start, stop = [x.strip() for x in brange.split("-", 1)] + if start: + if not stop: + stop = content_length - 1 + try: + start, stop = int(start), int(stop) + except Exception: + continue + if start >= content_length: + continue + if stop < start: + continue + stop = min(stop, content_length - 1) + result.append(Range(start, stop, stop - start + 1)) + elif stop: + # Negative subscript (last N bytes) + try: + stop = int(stop) + except Exception: + continue + if stop > content_length: + result.append(Range(0, content_length-1, content_length)) + else: + result.append(Range(content_length - stop, content_length - 1, stop)) + + return result +# }}} + +# gzip transfer encoding {{{ +def gzip_prefix(mtime=None): + # See http://www.gzip.org/zlib/rfc-gzip.html + if mtime is None: + mtime = time.time() + return b''.join(( + b'\x1f\x8b', # ID1 and ID2: gzip marker + b'\x08', # CM: compression method + b'\x00', # FLG: none set + # MTIME: 4 bytes + struct.pack(b" 0 + if keep_alive: + outheaders.set('Keep-Alive', 'timeout=%d' % self.opts.timeout) + if 'Connection' not in outheaders: + if self.response_protocol is HTTP11: + if 
self.close_after_response: + outheaders.set('Connection', 'close') + else: + if not self.close_after_response: + outheaders.set('Connection', 'Keep-Alive') + + ct = outheaders.get('Content-Type', '') + if ct.startswith('text/') and 'charset=' not in ct: + outheaders.set('Content-Type', ct + '; charset=UTF-8') + + buf = [HTTP11 + (' %d ' % data.status_code) + httplib.responses[data.status_code]] + for header, value in sorted(outheaders.iteritems(), key=itemgetter(0)): + buf.append('%s: %s' % (header, value)) + buf.append('') + self.response_ready(BytesIO(b''.join((x + '\r\n').encode('ascii') for x in buf)), output=output) + + def send_range_not_satisfiable(self, content_length): + buf = [ + '%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]), + "Date: " + http_date(), + "Content-Range: bytes */%d" % content_length, + ] + self.response_ready(header_list_to_file(buf)) + + def send_not_modified(self, etag=None): + buf = [ + '%s %d %s' % (self.response_protocol, httplib.NOT_MODIFIED, httplib.responses[httplib.NOT_MODIFIED]), + "Content-Length: 0", + "Date: " + http_date(), + ] + if etag is not None: + buf.append('ETag: ' + etag) + self.response_ready(header_list_to_file(buf)) + + def response_ready(self, header_file, output=None): + self.response_started = True + start_cork(self.socket) + self.corked = True + self.use_sendfile = False + self.set_state(WRITE, self.write_response_headers, header_file, output) + + def write_response_headers(self, buf, output, event): + if self.write(buf): + self.write_response_body(output) + + def write_response_body(self, output): + if output is None or self.method == 'HEAD': + self.reset_state() + return + if isinstance(output, ReadableOutput): + self.use_sendfile = output.use_sendfile and self.opts.use_sendfile and sendfile_to_socket_async is not None + if output.ranges is not None: + if isinstance(output.ranges, Range): + r = output.ranges + 
output.src_file.seek(r.start) + self.set_state(WRITE, self.write_buf, output.src_file, end=r.stop + 1) + else: + self.set_state(WRITE, self.write_ranges, output.src_file, output.ranges, first=True) + else: + self.set_state(WRITE, self.write_buf, output.src_file) + elif isinstance(output, GeneratedOutput): + self.set_state(WRITE, self.write_iter, chain(output.output, repeat(None, 1))) + else: + raise TypeError('Unknown output type: %r' % output) + + def write_buf(self, buf, event, end=None): + if self.write(buf, end=end): + self.reset_state() + + def write_ranges(self, buf, ranges, event, first=False): + r, range_part = next(ranges) + if r is None: + # EOF range part + self.set_state(WRITE, self.write_buf, BytesIO(b'\r\n' + range_part)) + else: + buf.seek(r.start) + self.set_state(WRITE, self.write_range_part, BytesIO((b'' if first else b'\r\n') + range_part + b'\r\n'), buf, r.stop + 1, ranges) + + def write_range_part(self, part_buf, buf, end, ranges, event): + if self.write(part_buf): + self.set_state(WRITE, self.write_range, buf, end, ranges) + + def write_range(self, buf, end, ranges, event): + if self.write(buf, end=end): + self.set_state(WRITE, self.write_ranges, buf, ranges) + + def write_iter(self, output, event): + chunk = next(output) + if chunk is None: + self.set_state(WRITE, self.write_chunk, BytesIO(b'0\r\n\r\n'), output, last=True) + else: + if not isinstance(chunk, bytes): + chunk = chunk.encode('utf-8') + chunk = ('%X\r\n' % len(chunk)).encode('ascii') + chunk + b'\r\n' + self.set_state(WRITE, self.write_chunk, BytesIO(chunk), output) + + def write_chunk(self, buf, output, event, last=False): + if self.write(buf): + if last: + self.reset_state() + else: + self.set_state(WRITE, self.write_iter, output) + + def reset_state(self): + self.connection_ready() + self.ready = not self.close_after_response + stop_cork(self.socket) + self.corked = False + + def report_unhandled_exception(self, e, formatted_traceback): + 
self.simple_response(httplib.INTERNAL_SERVER_ERROR)
+
+    def finalize_output(self, output, request, is_http1):
+        opts = self.opts
+        outheaders = request.outheaders
+        stat_result = file_metadata(output)
+        if stat_result is not None:
+            output = filesystem_file_output(output, outheaders, stat_result)
+            if 'Content-Type' not in outheaders:
+                mt = guess_type(output.name)[0]
+                if mt:
+                    if mt in {'text/plain', 'text/html', 'application/javascript', 'text/css'}:
+                        mt += '; charset=UTF-8'
+                    outheaders['Content-Type'] = mt
+        elif isinstance(output, (bytes, type(''))):
+            output = dynamic_output(output, outheaders)
+        elif hasattr(output, 'read'):
+            output = ReadableOutput(output)
+        elif isinstance(output, StaticOutput):
+            output = ReadableOutput(BytesIO(output.data), etag=output.etag, content_length=output.content_length)
+        else:
+            output = GeneratedOutput(output)
+        ct = outheaders.get('Content-Type', '').partition(';')[0]
+        compressible = (not ct or ct.startswith('text/') or ct.startswith('image/svg') or
+                        ct in {'application/json', 'application/javascript'})
+        compressible = (compressible and request.status_code == httplib.OK and
+                        (opts.compress_min_size > -1 and output.content_length >= opts.compress_min_size) and
+                        acceptable_encoding(request.inheaders.get('Accept-Encoding', '')) and not is_http1)
+        accept_ranges = (not compressible and output.accept_ranges is not None and request.status_code == httplib.OK and
+                         not is_http1)
+        ranges = get_ranges(request.inheaders.get('Range'), output.content_length) if output.accept_ranges and self.method in ('GET', 'HEAD') else None
+        if_range = (request.inheaders.get('If-Range') or '').strip()
+        if if_range and if_range != output.etag:
+            ranges = None
+        if ranges is not None and not ranges:
+            return self.send_range_not_satisfiable(output.content_length)
+
+        for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'):
+            outheaders.pop(header, all=True)
+
+        none_match =
parse_if_none_match(request.inheaders.get('If-None-Match', '')) + matched = '*' in none_match or (output.etag and output.etag in none_match) + if matched: + if self.method in ('GET', 'HEAD'): + self.send_not_modified(output.etag) + else: + self.simple_response(httplib.PRECONDITION_FAILED) + return + + output.ranges = None + + if output.etag and self.method in ('GET', 'HEAD'): + outheaders.set('ETag', output.etag, replace_all=True) + if accept_ranges: + outheaders.set('Accept-Ranges', 'bytes', replace_all=True) + if compressible and not ranges: + outheaders.set('Content-Encoding', 'gzip', replace_all=True) + output = GeneratedOutput(compress_readable_output(output.src_file), etag=output.etag) + if output.content_length is not None and not compressible and not ranges: + outheaders.set('Content-Length', '%d' % output.content_length, replace_all=True) + + if compressible or output.content_length is None: + outheaders.set('Transfer-Encoding', 'chunked', replace_all=True) + + if ranges: + if len(ranges) == 1: + r = ranges[0] + outheaders.set('Content-Length', '%d' % r.size, replace_all=True) + outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True) + output.ranges = r + else: + range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length) + size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges) + outheaders.set('Content-Length', '%d' % size, replace_all=True) + outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True) + output.ranges = izip_longest(ranges, range_parts) + request.status_code = httplib.PARTIAL_CONTENT + return output + +def create_http_handler(handler): + static_cache = {} # noqa + @wraps(handler) + def wrapper(*args, **kwargs): + ans = HTTPConnection(*args, **kwargs) + ans.request_handler = handler + ans.static_cache = {} + return ans + return wrapper diff --git a/src/calibre/srv/loop.py 
b/src/calibre/srv/loop.py index 58828363b2..05ae578950 100644 --- a/src/calibre/srv/loop.py +++ b/src/calibre/srv/loop.py @@ -6,510 +6,132 @@ from __future__ import (unicode_literals, division, absolute_import, __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' -import socket, os, ssl, time, sys -from operator import and_ -from Queue import Queue, Full -from threading import Thread, current_thread, Lock -from io import DEFAULT_BUFFER_SIZE, BytesIO +import ssl, socket, select, os, traceback +from io import BytesIO +from functools import partial -from calibre.srv.errors import MaxSizeExceeded +from calibre import as_unicode +from calibre.ptempfile import TemporaryDirectory from calibre.srv.opts import Options -from calibre.srv.utils import socket_errors_to_ignore, socket_error_eintr, socket_errors_nonblocking, Corked, HandleInterrupt +from calibre.srv.utils import ( + socket_errors_socket_closed, socket_errors_nonblocking, HandleInterrupt, socket_errors_eintr) from calibre.utils.socket_inheritance import set_socket_inherit from calibre.utils.logging import ThreadSafeLog +from calibre.utils.monotonic import monotonic -class SocketFile(object): # {{{ - """Faux file object attached to a socket object. Works with non-blocking - sockets, unlike the fileobject created by socket.makefile() """ +READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR' - name = "" +class Connection(object): - __slots__ = ( - "mode", "bufsize", "softspace", "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", "_wbuf_len", "_close", 'bytes_read', 'bytes_written', - ) - - def __init__(self, sock, bufsize=-1, close=False): - self._sock = sock - self.bytes_read = self.bytes_written = 0 - self.mode = 'r+b' - self.bufsize = DEFAULT_BUFFER_SIZE if bufsize < 0 else bufsize - self.softspace = False - # _rbufsize is the suggested recv buffer size. It is *strictly* - # obeyed within readline() for recv calls. If it is larger than - # default_bufsize it will be used for recv calls within read(). 
-        if self.bufsize == 0:
-            self._rbufsize = 1
-        elif bufsize == 1:
-            self._rbufsize = DEFAULT_BUFFER_SIZE
+    def __init__(self, socket, opts, ssl_context, tdir):
+        self.opts = opts
+        self.tdir = tdir
+        self.ssl_context = ssl_context
+        self.wait_for = READ
+        self.response_started = False
+        if self.ssl_context is not None:
+            self.ready = False
+            self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False)
+            self.set_state(RDWR, self.do_ssl_handshake)
         else:
-            self._rbufsize = bufsize
-        self._wbufsize = bufsize
-        # We use BytesIO for the read buffer to avoid holding a list
-        # of variously sized string objects which have been known to
-        # fragment the heap due to how they are malloc()ed and often
-        # realloc()ed down much smaller than their original allocation.
-        self._rbuf = BytesIO()
-        self._wbuf = []  # A list of strings
-        self._wbuf_len = 0
-        self._close = close
+            self.ready = True
+            self.socket = socket
+            self.connection_ready()
+        self.last_activity = monotonic()
-    @property
-    def closed(self):
-        return self._sock is None
+    def set_state(self, wait_for, func, *args, **kwargs):
+        self.wait_for = wait_for
+        if args or kwargs:
+            func = partial(func, *args, **kwargs)
+        self.handle_event = func
+
+    def do_ssl_handshake(self, event):
+        try:
+            # self.socket is the wrapped SSLSocket created in __init__;
+            # there is no self._sslobj attribute on Connection
+            self.socket.do_handshake()
+        except ssl.SSLWantReadError:
+            self.set_state(READ, self.do_ssl_handshake)
+        except ssl.SSLWantWriteError:
+            self.set_state(WRITE, self.do_ssl_handshake)
+        else:
+            # Mark ready only once the handshake has actually completed
+            self.ready = True
+            self.connection_ready()
+
+    def send(self, data):
+        try:
+            ret = self.socket.send(data)
+            self.last_activity = monotonic()
+            return ret
+        except socket.error as e:
+            if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
+                return 0
+            elif e.errno in socket_errors_socket_closed:
+                self.ready = False
+                return 0
+            raise
+
+    def recv(self, buffer_size):
+        try:
+            data = self.socket.recv(buffer_size)
+            self.last_activity = monotonic()
+            if not data:
+                # a closed connection is
indicated by signaling + # a read condition, and having recv() return 0. + self.ready = False + return b'' + return data + except socket.error as e: + if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr: + return b'' + if e.errno in socket_errors_socket_closed: + self.ready = False + return b'' + raise def close(self): - try: - if self._sock is not None: - try: - self.flush() - except socket.error: - pass - finally: - if self._close and self._sock is not None: - self._sock.close() - self._sock = None - - def __del__(self): - try: - self.close() - except: - # close() may fail if __init__ didn't complete - pass - - def fileno(self): - return self._sock.fileno() - - def gettimeout(self): - return self._sock.gettimeout() - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - def flush(self): - if self._wbuf_len: - data = b''.join(self._wbuf) - self._wbuf = [] - self._wbuf_len = 0 - data_size = len(data) - view = memoryview(data) - write_offset = 0 - buffer_size = max(self._rbufsize, DEFAULT_BUFFER_SIZE) - try: - while write_offset < data_size: - try: - bytes_sent = self._sock.send(view[write_offset:write_offset+buffer_size]) - write_offset += bytes_sent - self.bytes_written += bytes_sent - except socket.error as e: - if e.args[0] not in socket_errors_nonblocking: - raise - finally: - if write_offset < data_size: - remainder = data[write_offset:] - self._wbuf.append(remainder) - self._wbuf_len = len(remainder) - del view, data # explicit free - - def write(self, data): - if not isinstance(data, bytes): - raise TypeError('Cannot write data of type: %s to a socket' % type(data)) - if not data: - return - self._wbuf.append(data) - self._wbuf_len += len(data) - if self._wbufsize == 0 or (self._wbufsize == 1 and b'\n' in data) or (self._wbufsize > 1 and self._wbuf_len >= self._wbufsize): - self.flush() - - def writelines(self, lines): - for line in lines: - self.write(line) - - def recv(self, size): - while True: - try: 
- data = self._sock.recv(size) - self.bytes_read += len(data) - return data - except socket.error, e: - if e.args[0] not in socket_errors_nonblocking and e.args[0] not in socket_error_eintr: - raise - - def read(self, size=-1): - # Use max, disallow tiny reads in a loop as they are very inefficient. - # We never leave read() with any leftover data from a new recv() call - # in our internal buffer. - rbufsize = max(self._rbufsize, DEFAULT_BUFFER_SIZE) - buf = self._rbuf - buf.seek(0, os.SEEK_END) - if size < 0: - # Read until EOF - self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. - while True: - data = self.recv(rbufsize) - if not data: - break - buf.write(data) - return buf.getvalue() - else: - # Read until size bytes or EOF seen, whichever comes first - buf_len = buf.tell() - if buf_len >= size: - # Already have size bytes in our buffer? Extract and return. - buf.seek(0) - rv = buf.read(size) - self._rbuf = BytesIO() - self._rbuf.write(buf.read()) - return rv - - self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. - while True: - left = size - buf_len - # recv() will malloc the amount of memory given as its - # parameter even though it often returns much less data - # than that. The returned data string is short lived - # as we copy it into a StringIO and free it. This avoids - # fragmentation issues on many platforms. - data = self.recv(left) - if not data: - break - n = len(data) - if n == size and not buf_len: - # Shortcut. Avoid buffer data copies when: - # - We have no data in our buffer. - # AND - # - Our call to recv returned exactly the - # number of bytes we were asked to read. 
- return data - if n == left: - buf.write(data) - del data # explicit free - break - buf.write(data) # noqa - buf_len += n - del data # noqa explicit free - return buf.getvalue() - - def readline(self, size=-1, maxsize=sys.maxsize): - buf = self._rbuf - buf.seek(0, os.SEEK_END) - if buf.tell() > 0: - # check if we already have it in our buffer - buf.seek(0) - bline = buf.readline(size) - self._rbuf = BytesIO() - if bline.endswith(b'\n') or len(bline) == size: - self._rbuf.write(buf.read()) - if len(bline) > maxsize: - raise MaxSizeExceeded('Line length', len(bline), maxsize) - return bline - else: - self._rbuf.write(bline) - self._rbuf.write(buf.read()) - del bline - - if size < 0: - # Read until \n or EOF, whichever comes first - if self._rbufsize <= 1: - # Speed up unbuffered case - buf.seek(0) - buffers = [buf.read()] - self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. - data = None - recv = self.recv - sz = len(buffers[0]) - while data != b'\n': - data = recv(1) - if not data: - break - sz += 1 - if sz > maxsize: - raise MaxSizeExceeded('Line length', sz, maxsize) - buffers.append(data) - return b''.join(buffers) - - buf.seek(0, os.SEEK_END) - self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. - while True: - data = self.recv(self._rbufsize) - if not data: - break - nl = data.find(b'\n') - if nl >= 0: - nl += 1 - buf.write(data[:nl]) - self._rbuf.write(data[nl:]) - del data - break - buf.write(data) # noqa - if buf.tell() > maxsize: - raise MaxSizeExceeded('Line length', buf.tell(), maxsize) - return buf.getvalue() - else: - # Read until size bytes or \n or EOF seen, whichever comes first - buf.seek(0, os.SEEK_END) - buf_len = buf.tell() - if buf_len >= size: - buf.seek(0) - rv = buf.read(size) - self._rbuf = BytesIO() - self._rbuf.write(buf.read()) - if len(rv) > maxsize: - raise MaxSizeExceeded('Line length', len(rv), maxsize) - return rv - self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. 
- while True: - data = self.recv(self._rbufsize) - if not data: - break - left = size - buf_len - # did we just receive a newline? - nl = data.find(b'\n', 0, left) - if nl >= 0: - nl += 1 - # save the excess data to _rbuf - self._rbuf.write(data[nl:]) - if buf_len: - buf.write(data[:nl]) - break - else: - # Shortcut. Avoid data copy through buf when returning - # a substring of our first recv() and buf has no - # existing data. - if nl > maxsize: - raise MaxSizeExceeded('Line length', nl, maxsize) - return data[:nl] - n = len(data) - if n == size and not buf_len: - # Shortcut. Avoid data copy through buf when - # returning exactly all of our first recv(). - if n > maxsize: - raise MaxSizeExceeded('Line length', n, maxsize) - return data - if n >= left: - buf.write(data[:left]) - self._rbuf.write(data[left:]) - break - buf.write(data) - buf_len += n - if buf.tell() > maxsize: - raise MaxSizeExceeded('Line length', buf.tell(), maxsize) - return buf.getvalue() - - def readlines(self, sizehint=0, maxsize=sys.maxsize): - total = 0 - ans = [] - while True: - line = self.readline(maxsize=maxsize) - if not line: - break - ans.append(line) - total += len(line) - if sizehint and total >= sizehint: - break - return ans - - def __iter__(self): - line = True - while line: - line = self.readline() - if line: - yield line - -# }}} - -class Connection(object): # {{{ - - ' A thin wrapper around an active socket ' - - remote_addr = None - remote_port = None - - def __init__(self, server_loop, socket): - self.server_loop = server_loop - self.socket = socket - self.corked = Corked(socket) - self.socket_file = SocketFile(socket) - self.closed = False - - def close(self): - """Close the socket underlying this connection.""" - if self.closed: - return - self.socket_file.close() + self.ready = False + self.handle_event = None # prevent reference cycles try: self.socket.shutdown(socket.SHUT_WR) self.socket.close() except socket.error: pass - self.closed = True - - def __enter__(self): - 
return self - - def __exit__(self, *args): - self.close() -# }}} - -class WorkerThread(Thread): # {{{ - - daemon = True - - def __init__(self, server_loop): - self.serving = False - self.server_loop = server_loop - self.conn = None - self.forcible_shutdown = False - Thread.__init__(self, name='ServerWorker') - - def run(self): - try: - while True: - self.serving = False - self.conn = conn = self.server_loop.requests.get() - if conn is None: - return # Clean exit - with conn, self: - self.server_loop.req_resp_handler(conn) - except (KeyboardInterrupt, SystemExit): - self.server_loop.stop() - except socket.error: - if not self.forcible_shutdown: - raise - - def __enter__(self): - self.serving = True - return self - - def __exit__(self, *args): - self.serving = False -# }}} - -class ThreadPool(object): # {{{ - - def __init__(self, server_loop, min_threads=10, max_threads=-1, accepted_queue_size=-1, accepted_queue_timeout=10): - self.server_loop = server_loop - self.min_threads = max(1, min_threads) - self.max_threads = max_threads - self._threads = [] - self._queue = Queue(maxsize=accepted_queue_size) - self._queue_put_timeout = accepted_queue_timeout - self.get = self._queue.get - - def start(self): - """Start the pool of threads.""" - self._threads = [self._spawn_worker() for i in xrange(self.min_threads)] @property - def idle(self): - return sum(int(not w.serving) for w in self._threads) + def state_description(self): + return '' - @property - def busy(self): - return sum(int(w.serving) for w in self._threads) + def report_unhandled_exception(self, e, formatted_traceback): + pass - def put(self, obj): - self._queue.put(obj, block=True, timeout=self._queue_put_timeout) + def connection_ready(self): + raise NotImplementedError() - def grow(self, amount): - """Spawn new worker threads (not above self.max_threads).""" - budget = max(self.max_threads - len(self._threads), 0) if self.max_threads > 0 else sys.maxsize - n_new = min(amount, budget) - 
self._threads.extend([self._spawn_worker() for i in xrange(n_new)]) - - def _spawn_worker(self): - worker = WorkerThread(self.server_loop) - worker.start() - return worker - - @staticmethod - def _all(func, items): - results = [func(item) for item in items] - return reduce(and_, results, True) - - def shrink(self, amount): - """Kill off worker threads (not below self.min_threads).""" - # Grow/shrink the pool if necessary. - # Remove any dead threads from our list - orig = len(self._threads) - self._threads = [t for t in self._threads if t.is_alive()] - amount -= orig - len(self._threads) - - # calculate the number of threads above the minimum - n_extra = max(len(self._threads) - self.min_threads, 0) - - # don't remove more than amount - n_to_remove = min(amount, n_extra) - - # put shutdown requests on the queue equal to the number of threads - # to remove. As each request is processed by a worker, that worker - # will terminate and be culled from the list. - for n in xrange(n_to_remove): - self._queue.put(None) - - def stop(self, timeout=5): - # Must shut down threads here so the code that calls - # this method can know when all threads are stopped. - for worker in self._threads: - self._queue.put(None) - - # Don't join the current thread (this should never happen, since - # ServerLoop calls stop() in its own thread, but better to be safe). - current = current_thread() - if timeout and timeout >= 0: - endtime = time.time() + timeout - while self._threads: - worker = self._threads.pop() - if worker is not current and worker.is_alive(): - try: - if timeout is None or timeout < 0: - worker.join() - else: - remaining_time = endtime - time.time() - if remaining_time > 0: - worker.join(remaining_time) - if worker.is_alive(): - # We exhausted the timeout. - # Forcibly shut down the socket. 
- worker.forcible_shutdown = True - c = worker.conn - if c and not c.socket_file.closed: - c.socket.shutdown(socket.SHUT_RDWR) - c.socket.close() - worker.join() - except KeyboardInterrupt: - pass # Ignore repeated Ctrl-C. - - @property - def qsize(self): - return self._queue.qsize() -# }}} + def handle_timeout(self): + return False class ServerLoop(object): def __init__( self, - req_resp_handler, + handler, bind_address=('localhost', 8080), opts=None, - # A calibre logging object. If None a default log that logs to + # A calibre logging object. If None, a default log that logs to # stdout is used log=None ): + self.ready = False + self.handler = handler self.opts = opts or Options() - self.req_resp_handler = req_resp_handler self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG) - self.gso_cache, self.gso_lock = {}, Lock() - ba = bind_address - if not isinstance(ba, basestring): - ba = tuple(ba) - if not ba[0]: - # AI_PASSIVE does not work with host of '' or None - ba = ('0.0.0.0', ba[1]) + + ba = tuple(bind_address) + if not ba[0]: + # AI_PASSIVE does not work with host of '' or None + ba = ('0.0.0.0', ba[1]) self.bind_address = ba self.bound_address = None + self.connection_map = {} + self.ssl_context = None if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None: self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) @@ -523,51 +145,33 @@ class ServerLoop(object): set_socket_inherit(self.pre_activated_socket, False) self.bind_address = self.pre_activated_socket.getsockname() - self.ready = False - self.requests = ThreadPool(self, min_threads=self.opts.min_threads, max_threads=self.opts.max_threads) - def __str__(self): return "%s(%r)" % (self.__class__.__name__, self.bind_address) __repr__ = __str__ + @property + def num_active_connections(self): + return len(self.connection_map) + def serve_forever(self): """ Listen for incoming connections. 
""" if self.pre_activated_socket is None: - # Select the appropriate socket - if isinstance(self.bind_address, basestring): - # AF_UNIX socket - - # So we can reuse the socket... - try: - os.unlink(self.bind_address) - except EnvironmentError: - pass - - # So everyone can access the socket... - try: - os.chmod(self.bind_address, 0777) - except EnvironmentError: - pass - - info = [ - (socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_address)] - else: - # AF_INET or AF_INET6 socket - # Get the correct address family for our host (allows IPv6 - # addresses) - host, port = self.bind_address - try: - info = socket.getaddrinfo( - host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, 0, socket.AI_PASSIVE) - except socket.gaierror: - if ':' in host: - info = [(socket.AF_INET6, socket.SOCK_STREAM, - 0, "", self.bind_address + (0, 0))] - else: - info = [(socket.AF_INET, socket.SOCK_STREAM, - 0, "", self.bind_address)] + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 + # addresses) + host, port = self.bind_address + try: + info = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + except socket.gaierror: + if ':' in host: + info = [(socket.AF_INET6, socket.SOCK_STREAM, + 0, "", self.bind_address + (0, 0))] + else: + info = [(socket.AF_INET, socket.SOCK_STREAM, + 0, "", self.bind_address)] self.socket = None msg = "No socket could be created" @@ -589,28 +193,28 @@ class ServerLoop(object): self.pre_activated_socket = None self.setup_socket() - self.socket.listen(5) + self.connection_map = {} + self.socket.listen(min(socket.SOMAXCONN, 128)) self.bound_address = ba = self.socket.getsockname() if isinstance(ba, tuple): ba = ':'.join(map(type(''), ba)) - self.log('calibre server listening on', ba) + with TemporaryDirectory(prefix='srv-') as tdir: + self.tdir = tdir + self.ready = True + self.log('calibre server listening on', ba) - # Create worker threads - self.requests.start() - self.ready 
= True - - while self.ready: - try: - self.tick() - except (KeyboardInterrupt, SystemExit): - raise - except: - self.log.exception('Error in ServerLoop.tick') + while True: + try: + self.tick() + except (KeyboardInterrupt, SystemExit): + self.shutdown() + break + except: + self.log.exception('Error in ServerLoop.tick') def setup_socket(self): self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if self.opts.no_delay and not isinstance(self.bind_address, basestring): - self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), # activate dual-stack. @@ -623,149 +227,159 @@ class ServerLoop(object): # Apparently, the socket option is not available in # this machine's TCP stack pass + self.socket.setblocking(0) def bind(self, family, atype, proto=0): - """Create (or recreate) the actual socket object.""" + '''Create (or recreate) the actual socket object.''' self.socket = socket.socket(family, atype, proto) set_socket_inherit(self.socket, False) self.setup_socket() self.socket.bind(self.bind_address) def tick(self): - """Accept a new connection and put it on the Queue.""" + now = monotonic() + for s, conn in tuple(self.connection_map.iteritems()): + if now - conn.last_activity > self.opts.timeout: + if not conn.handle_timeout(): + self.log.debug('Closing connection because of extended inactivity') + self.close(s, conn) + + read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR] + write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR] try: - s, addr = self.socket.accept() - if not self.ready: + readable, writable, _ = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout) + except select.error as e: + if e.errno in socket_errors_eintr: return - - set_socket_inherit(s, False) - if 
hasattr(s, 'settimeout'): - s.settimeout(self.opts.timeout) - - if self.ssl_context is not None: + for s, conn in tuple(self.connection_map.iteritems()): try: - s = self.ssl_context.wrap_socket(s, server_side=True) - except ssl.SSLEOFError: - return # Ignore, client closed connection - except ssl.SSLError as e: - if e.args[1].endswith('http request'): - msg = (b"The client sent a plain HTTP request, but " - b"this server only speaks HTTPS on this port.") - response = [ - b"HTTP/1.1 400 Bad Request\r\n", - str("Content-Length: %s\r\n" % len(msg)), - b"Content-Type: text/plain\r\n\r\n", - msg - ] - with SocketFile(s._sock) as f: - f.write(response) - return - elif e.args[1].endswith('unknown protocol'): - return # Drop connection - raise - if hasattr(s, 'settimeout'): - s.settimeout(self.opts.timeout) + select.select([s], [], [], 0) + except select.error: + self.close(s, conn) # Bad socket, discard - conn = Connection(self, s) - - if not isinstance(self.bind_address, basestring): - # optional values - # Until we do DNS lookups, omit REMOTE_HOST - if addr is None: # sometimes this can happen - # figure out if AF_INET or AF_INET6. - if len(s.getsockname()) == 2: - # AF_INET - addr = ('0.0.0.0', 0) - else: - # AF_INET6 - addr = ('::', 0) - conn.remote_addr = addr[0] - conn.remote_port = addr[1] - - try: - self.requests.put(conn) - except Full: - self.log.warn('Server overloaded, dropping connection') - conn.close() - return - except socket.timeout: - # The only reason for the timeout in start() is so we can - # notice keyboard interrupts on Win32, which don't interrupt - # accept() by default - return - except socket.error as e: - if e.args[0] in socket_error_eintr | socket_errors_nonblocking | socket_errors_to_ignore: - return - raise - - def stop(self): - """ Gracefully shutdown the server loop. 
""" if not self.ready: return - # We run the stop code in its own thread so that it is not interrupted - # by KeyboardInterrupt - self.ready = False - t = Thread(target=self._stop) - t.start() - try: - t.join() - except KeyboardInterrupt: - pass - def _stop(self): - self.log('Shutting down server gracefully, waiting for connections to close...') - self.requests.stop(self.opts.shutdown_timeout) - sock = self.tick_once() - if hasattr(sock, "close"): - sock.close() - self.socket = None + ignore = set() + for s, conn, event in self.get_actions(readable, writable): + if s in ignore: + continue + try: + conn.handle_event(event) + if not conn.ready: + self.close(s, conn) + except Exception as e: + ignore.add(s) + if conn.ready: + self.log.exception('Unhandled exception in state: %s' % conn.state_description) + if conn.response_started: + self.close(s, conn) + else: + try: + conn.report_unhandled_exception(e, traceback.format_exc()) + except Exception: + self.close(s, conn) + else: + self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e)) + self.close(s, conn) - def tick_once(self): - # Touch our own socket to make accept() return immediately. + def wakeup(self): + # Touch our own socket to make select() return immediately. 
sock = getattr(self, "socket", None) if sock is not None: - if not isinstance(self.bind_address, basestring): - try: - host, port = sock.getsockname()[:2] - except socket.error as e: - if e.args[0] not in socket_errors_to_ignore: - raise - else: - # Ensure tick() returns by opening a transient connection - # to our own listening socket - for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - s = None - try: - s = socket.socket(af, socktype, proto) - s.settimeout(1.0) - s.connect((host, port)) + try: + host, port = sock.getsockname()[:2] + except socket.error as e: + if e.errno not in socket_errors_socket_closed: + raise + else: + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s is not None: s.close() - except socket.error: - if s is not None: - s.close() return sock -def echo_handler(conn): - keep_going = True - while keep_going: + def close(self, s, conn): + self.connection_map.pop(s, None) + conn.close() + + def get_actions(self, readable, writable): + for s in readable: + if s is self.socket: + s, addr = self.accept() + if s is not None: + self.connection_map[s] = conn = self.handler(s, self.opts, self.ssl_context, self.tdir) + if self.ssl_context is not None: + yield s, conn, RDWR + else: + yield s, self.connection_map[s], READ + for s in writable: + yield s, self.connection_map[s], WRITE + + def accept(self): try: - line = conn.socket_file.readline() - except socket.timeout: - continue - conn.server_loop.log('Received:', repr(line)) - if not line.rstrip(): - keep_going = False - line = b'bye\r\n' - conn.socket_file.write(line) - conn.socket_file.flush() + return self.socket.accept() + except socket.error: + return None, None + + def stop(self): + self.ready = False + 
self.wakeup() + + def shutdown(self): + try: + if getattr(self, 'socket', None): + self.socket.close() + self.socket = None + except socket.error: + pass + for s, conn in tuple(self.connection_map.iteritems()): + self.close(s, conn) + +class EchoLine(Connection): # {{{ + + bye_after_echo = False + + def connection_ready(self): + self.rbuf = BytesIO() + self.set_state(READ, self.read_line) + + def read_line(self, event): + data = self.recv(1) + if data: + self.rbuf.write(data) + if b'\n' == data: + if self.rbuf.tell() < 3: + # Empty line + self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue()) + self.bye_after_echo = True + self.set_state(WRITE, self.echo) + self.rbuf.seek(0) + + def echo(self, event): + pos = self.rbuf.tell() + self.rbuf.seek(0, os.SEEK_END) + left = self.rbuf.tell() - pos + self.rbuf.seek(pos) + sent = self.send(self.rbuf.read(512)) + if sent == left: + self.rbuf = BytesIO() + self.set_state(READ, self.read_line) + if self.bye_after_echo: + self.ready = False + else: + self.rbuf.seek(pos + sent) +# }}} if __name__ == '__main__': - s = ServerLoop(echo_handler) - with HandleInterrupt(s.tick_once): - try: - s.serve_forever() - except KeyboardInterrupt: - pass - s.stop() + s = ServerLoop(EchoLine) + with HandleInterrupt(s.wakeup): + s.serve_forever() diff --git a/src/calibre/srv/opts.py b/src/calibre/srv/opts.py index 482564f2cd..1e027e50b9 100644 --- a/src/calibre/srv/opts.py +++ b/src/calibre/srv/opts.py @@ -36,14 +36,6 @@ raw_options = ( 'shutdown_timeout', 5.0, None, - 'Minimum number of connection handling threads', - 'min_threads', 10, - None, - - 'Maximum number of simultaneous connections (beyond this number of connections will be dropped)', - 'max_threads', 500, - None, - 'Allow socket pre-allocation, for example, with systemd socket activation', 'allow_socket_preallocation', True, None, @@ -52,7 +44,7 @@ raw_options = ( 'max_header_line_size', 8.0, None, - 'Max. size of a HTTP request (in MB)', + 'Max. 
allowed size for files uploaded to the server (in MB)', 'max_request_body_size', 500.0, None, @@ -60,12 +52,6 @@ raw_options = ( 'compress_min_size', 1024, None, - 'Decrease latency by using the TCP_NODELAY feature', - 'no_delay', True, - 'no_delay turns on TCP_NODELAY which decreases latency at the cost of' - ' worse overall performance when sending multiple small packets. It' - ' prevents the TCP stack from aggregating multiple small TCP packets.', - 'Use zero copy file transfers for increased performance', 'use_sendfile', True, 'This will use zero-copy in-kernel transfers when sending files over the network,' diff --git a/src/calibre/srv/respond.py b/src/calibre/srv/respond.py deleted file mode 100644 index 3f80b10aa7..0000000000 --- a/src/calibre/srv/respond.py +++ /dev/null @@ -1,356 +0,0 @@ -#!/usr/bin/env python2 -# vim:fileencoding=utf-8 -from __future__ import (unicode_literals, division, absolute_import, - print_function) - -__license__ = 'GPL v3' -__copyright__ = '2015, Kovid Goyal ' - -import os, hashlib, httplib, zlib, struct, time, uuid -from io import DEFAULT_BUFFER_SIZE, BytesIO -from collections import namedtuple -from functools import partial -from future_builtins import map -from itertools import izip_longest - -from calibre import force_unicode, guess_type -from calibre.srv.errors import IfNoneMatch, RangeNotSatisfiable -from calibre.srv.sendfile import file_metadata, copy_range, sendfile_to_socket - -Range = namedtuple('Range', 'start stop size') -MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii') - -def get_ranges(headervalue, content_length): - ''' Return a list of ranges from the Range header. If this function returns - an empty list, it indicates no valid range was found. 
''' - if not headervalue: - return None - - result = [] - try: - bytesunit, byteranges = headervalue.split("=", 1) - except Exception: - return None - if bytesunit.strip() != 'bytes': - return None - - for brange in byteranges.split(","): - start, stop = [x.strip() for x in brange.split("-", 1)] - if start: - if not stop: - stop = content_length - 1 - try: - start, stop = int(start), int(stop) - except Exception: - continue - if start >= content_length: - continue - if stop < start: - continue - stop = min(stop, content_length - 1) - result.append(Range(start, stop, stop - start + 1)) - elif stop: - # Negative subscript (last N bytes) - try: - stop = int(stop) - except Exception: - continue - if stop > content_length: - result.append(Range(0, content_length-1, content_length)) - else: - result.append(Range(content_length - stop, content_length - 1, stop)) - - return result - - -def acceptable_encoding(val, allowed=frozenset({'gzip'})): - def enc(x): - e, r = x.partition(';')[::2] - p, v = r.partition('=')[::2] - q = 1.0 - if p == 'q' and v: - try: - q = float(v) - except Exception: - pass - return e.lower(), q - - emap = dict(enc(x.strip()) for x in val.split(',')) - acceptable = sorted(set(emap) & allowed, key=emap.__getitem__, reverse=True) - if acceptable: - return acceptable[0] - -def gzip_prefix(mtime): - # See http://www.gzip.org/zlib/rfc-gzip.html - return b''.join(( - b'\x1f\x8b', # ID1 and ID2: gzip marker - b'\x08', # CM: compression method - b'\x00', # FLG: none set - # MTIME: 4 bytes - struct.pack(b"'), sent, self.content_length)) - self.src_file = None - - def write_compressed(self, dest): - self.src_file.seek(0) - write_compressed_file_obj(self.src_file, dest) - self.src_file = None - - def write_ranges(self, ranges, dest): - if isinstance(ranges, Range): - r = ranges - self.copy_range(r.start, r.size, dest) - else: - for r, header in ranges: - dest.write(header) - if r is not None: - dest.write(b'\r\n') - self.copy_range(r.start, r.size, dest) - 
dest.write(b'\r\n') - self.src_file = None - - def copy_range(self, start, size, dest): - if self.use_sendfile: - dest.flush() # Ensure everything in the SocketFile buffer is sent before calling sendfile() - sent = sendfile_to_socket(self.src_file, start, size, dest) - else: - sent = copy_range(self.src_file, start, size, dest) - if sent != size: - raise IOError('Failed to send byte range from file (%r) (%s != %s bytes), perhaps the file was modified during send?' % ( - getattr(self.src_file, 'name', ''), sent, size)) - -class FileSystemOutputFile(ReadableOutput): - - def __init__(self, output, outheaders, stat_result, use_sendfile): - self.src_file = output - self.name = output.name - self.content_length = stat_result.st_size - self.etag = '"%s"' % hashlib.sha1(type('')(stat_result.st_mtime) + force_unicode(output.name or '')).hexdigest() - self.accept_ranges = True - self.use_sendfile = use_sendfile and sendfile_to_socket is not None - - -class DynamicOutput(object): - - def __init__(self, output, outheaders): - if isinstance(output, bytes): - self.data = output - else: - self.data = output.encode('utf-8') - ct = outheaders.get('Content-Type') - if not ct: - outheaders.set('Content-Type', 'text/plain; charset=UTF-8', replace_all=True) - self.content_length = len(self.data) - self.etag = None - self.accept_ranges = False - - def write(self, dest): - dest.write(self.data) - self.data = None - - def write_compressed(self, dest): - write_compressed_file_obj(BytesIO(self.data), dest) - -class GeneratedOutput(object): - - def __init__(self, output, outheaders): - self.output = output - self.content_length = self.etag = None - self.accept_ranges = False - - def write(self, dest): - for line in self.output: - if line: - write_chunked_data(dest, line) - -class StaticGeneratedOutput(object): - - def __init__(self, data): - if isinstance(data, type('')): - data = data.encode('utf-8') - self.data = data - self.etag = '"%s"' % hashlib.sha1(data).hexdigest() - 
self.content_length = len(data) - self.accept_ranges = False - - def write(self, dest): - dest.write(self.data) - - def write_compressed(self, dest): - write_compressed_file_obj(BytesIO(self.data), dest) - -def generate_static_output(cache, gso_lock, name, generator): - with gso_lock: - ans = cache.get(name) - if ans is None: - ans = cache[name] = StaticGeneratedOutput(generator()) - return ans - -def parse_if_none_match(val): - return {x.strip() for x in val.split(',')} - -def finalize_output(output, inheaders, outheaders, status_code, is_http1, method, opts): - ct = outheaders.get('Content-Type', '') - compressible = not ct or ct.startswith('text/') or ct.startswith('image/svg') or ct.startswith('application/json') - stat_result = file_metadata(output) - if stat_result is not None: - output = FileSystemOutputFile(output, outheaders, stat_result, opts.use_sendfile) - if 'Content-Type' not in outheaders: - mt = guess_type(output.name)[0] - if mt: - if mt in ('text/plain', 'text/html'): - mt =+ '; charset=UTF-8' - outheaders['Content-Type'] = mt - elif isinstance(output, (bytes, type(''))): - output = DynamicOutput(output, outheaders) - elif hasattr(output, 'read'): - output = ReadableOutput(output, outheaders) - elif isinstance(output, StaticGeneratedOutput): - pass - else: - output = GeneratedOutput(output, outheaders) - compressible = (status_code == httplib.OK and compressible and - (opts.compress_min_size > -1 and output.content_length >= opts.compress_min_size) and - acceptable_encoding(inheaders.get('Accept-Encoding', '')) and not is_http1) - accept_ranges = (not compressible and output.accept_ranges is not None and status_code == httplib.OK and - not is_http1) - ranges = get_ranges(inheaders.get('Range'), output.content_length) if output.accept_ranges and method in ('GET', 'HEAD') else None - if_range = (inheaders.get('If-Range') or '').strip() - if if_range and if_range != output.etag: - ranges = None - if ranges is not None and not ranges: - raise 
RangeNotSatisfiable(output.content_length) - - for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'): - outheaders.pop('header', all=True) - - none_match = parse_if_none_match(inheaders.get('If-None-Match', '')) - matched = '*' in none_match or (output.etag and output.etag in none_match) - if matched: - raise IfNoneMatch(output.etag) - - if output.etag and method in ('GET', 'HEAD'): - outheaders.set('ETag', output.etag, replace_all=True) - if accept_ranges: - outheaders.set('Accept-Ranges', 'bytes', replace_all=True) - elif compressible: - outheaders.set('Content-Encoding', 'gzip', replace_all=True) - if output.content_length is not None and not compressible and not ranges: - outheaders.set('Content-Length', '%d' % output.content_length, replace_all=True) - - if compressible or output.content_length is None: - outheaders.set('Transfer-Encoding', 'chunked', replace_all=True) - - if ranges: - if len(ranges) == 1: - r = ranges[0] - outheaders.set('Content-Length', '%d' % r.size, replace_all=True) - outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True) - output.commit = partial(output.write_ranges, r) - else: - range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length) - size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges) - outheaders.set('Content-Length', '%d' % size, replace_all=True) - outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True) - output.commit = partial(output.write_ranges, izip_longest(ranges, range_parts)) - status_code = httplib.PARTIAL_CONTENT - else: - output.commit = output.write_compressed if compressible else output.write - - return status_code, output diff --git a/src/calibre/srv/sendfile.py b/src/calibre/srv/sendfile.py index 33fd5388af..a740008332 100644 --- a/src/calibre/srv/sendfile.py +++ b/src/calibre/srv/sendfile.py @@ -10,7 +10,7 @@ 
import os, ctypes, errno, socket from io import DEFAULT_BUFFER_SIZE from select import select -from calibre.constants import iswindows, isosx +from calibre.constants import islinux, isosx from calibre.srv.utils import eintr_retry_call def file_metadata(fileobj): @@ -33,10 +33,15 @@ def copy_range(src_file, start, size, dest): del data return total_sent +class CannotSendfile(Exception): + pass -if iswindows: - sendfile_to_socket = None -elif isosx: +class SendfileInterrupted(Exception): + pass + +sendfile_to_socket = sendfile_to_socket_async = None + +if isosx: libc = ctypes.CDLL(None, use_errno=True) sendfile = ctypes.CFUNCTYPE( ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int64, ctypes.POINTER(ctypes.c_int64), ctypes.c_void_p, ctypes.c_int, use_errno=True)( @@ -68,7 +73,19 @@ elif isosx: offset += num_bytes.value return total_sent -else: + def sendfile_to_socket_async(fileobj, offset, size, socket_file): + num_bytes = ctypes.c_int64(size) + ret = sendfile(fileobj.fileno(), socket_file.fileno(), offset, ctypes.byref(num_bytes), None, 0) + if ret != 0: + err = ctypes.get_errno() + if err in (errno.EBADF, errno.ENOTSUP, errno.ENOTSOCK, errno.EOPNOTSUPP): + raise CannotSendfile() + if err in (errno.EINTR, errno.EAGAIN): + raise SendfileInterrupted() + raise IOError((err, os.strerror(err))) + return num_bytes.value + +elif islinux: libc = ctypes.CDLL(None, use_errno=True) sendfile = ctypes.CFUNCTYPE( ctypes.c_ssize_t, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t, use_errno=True)(('sendfile64', libc)) @@ -97,3 +114,15 @@ else: size -= sent total_sent += sent return total_sent + + def sendfile_to_socket_async(fileobj, offset, size, socket_file): + off = ctypes.c_int64(offset) + sent = sendfile(socket_file.fileno(), fileobj.fileno(), ctypes.byref(off), size) + if sent < 0: + err = ctypes.get_errno() + if err in (errno.ENOSYS, errno.EINVAL): + raise CannotSendfile() + if err in (errno.EINTR, errno.EAGAIN): + raise SendfileInterrupted() 
+ raise IOError((err, os.strerror(err))) + return sent diff --git a/src/calibre/srv/tests/base.py b/src/calibre/srv/tests/base.py index f02d39d077..416bf2b66f 100644 --- a/src/calibre/srv/tests/base.py +++ b/src/calibre/srv/tests/base.py @@ -36,7 +36,7 @@ class TestServer(Thread): Thread.__init__(self, name='ServerMain') from calibre.srv.opts import Options from calibre.srv.loop import ServerLoop - from calibre.srv.http import create_http_handler + from calibre.srv.http_response import create_http_handler kwargs['shutdown_timeout'] = kwargs.get('shutdown_timeout', 0.1) self.loop = ServerLoop( create_http_handler(handler), @@ -68,5 +68,5 @@ class TestServer(Thread): return httplib.HTTPConnection(self.address[0], self.address[1], strict=True, timeout=timeout) def change_handler(self, handler): - from calibre.srv.http import create_http_handler - self.loop.req_resp_handler = create_http_handler(handler) + from calibre.srv.http_response import create_http_handler + self.loop.handler = create_http_handler(handler) diff --git a/src/calibre/srv/tests/http.py b/src/calibre/srv/tests/http.py index ee9c5e61c6..89220d095f 100644 --- a/src/calibre/srv/tests/http.py +++ b/src/calibre/srv/tests/http.py @@ -6,57 +6,56 @@ from __future__ import (unicode_literals, division, absolute_import, __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' -import textwrap, httplib, hashlib, zlib, string +import httplib, hashlib, zlib, string from io import BytesIO from tempfile import NamedTemporaryFile from calibre import guess_type from calibre.srv.tests.base import BaseTest, TestServer -def headers(raw): - return BytesIO(textwrap.dedent(raw).encode('utf-8')) - class TestHTTP(BaseTest): def test_header_parsing(self): # {{{ 'Test parsing of HTTP headers' - from calibre.srv.http import read_headers + from calibre.srv.http_request import HTTPHeaderParser - def test(name, raw, **kwargs): - hdict = read_headers(headers(raw).readline) - self.assertSetEqual(set(hdict.items()), 
{(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed') + def test(name, *lines, **kwargs): + p = HTTPHeaderParser() + p.push(*lines) + self.assertTrue(p.finished) + self.assertSetEqual(set(p.hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed') test('Continuation line parsing', - '''\ - a: one\r - b: two\r - 2\r - \t3\r - c:three\r - \r\n''', a='one', b='two 2 3', c='three') + 'a: one', + 'b: two', + ' 2', + '\t3', + 'c:three', + '\r\n', a='one', b='two 2 3', c='three') test('Non-ascii headers parsing', - '''\ - a:mūs\r - \r\n''', a='mūs'.encode('utf-8')) + b'a:mūs\r', '\r\n', a='mūs'.encode('utf-8')) test('Comma-separated parsing', - '''\ - Accept-Encoding: one\r - Accept-Encoding: two\r - \r\n''', accept_encoding='one, two') + 'Accept-Encoding: one', + 'accept-Encoding: two', + '\r\n', accept_encoding='one, two') + + def parse(line): + HTTPHeaderParser()(line) with self.assertRaises(ValueError): - read_headers(headers('Connection:mūs\r\n').readline) - read_headers(headers('Connection\r\n').readline) - read_headers(headers('Connection:a\r\n').readline) - read_headers(headers('Connection:a\n').readline) - read_headers(headers(' Connection:a\n').readline) + parse('Connection:mūs\r\n') + parse('Connection\r\n') + parse('Connection:a\r\n') + parse('Connection:a\n') + parse(' Connection:a\n') + parse(':a\n') # }}} def test_accept_encoding(self): # {{{ 'Test parsing of Accept-Encoding' - from calibre.srv.respond import acceptable_encoding + from calibre.srv.http_response import acceptable_encoding def test(name, val, ans, allowed={'gzip'}): self.ae(acceptable_encoding(val, allowed), ans, name + ' failed') test('Empty field', '', None) @@ -68,7 +67,7 @@ class TestHTTP(BaseTest): def test_range_parsing(self): # {{{ 'Test parsing of Range header' - from calibre.srv.respond import get_ranges + from calibre.srv.http_response import get_ranges def test(val, *args): pval = get_ranges(val, 100) if 
len(args) == 1 and args[0] is None: @@ -91,11 +90,38 @@ class TestHTTP(BaseTest): 'Test basic HTTP protocol conformance' from calibre.srv.errors import HTTP404 body = 'Requested resource not found' - def handler(conn): + def handler(data): raise HTTP404(body) + def raw_send(conn, raw): + conn.send(raw) + conn._HTTPConnection__state = httplib._CS_REQ_SENT + return conn.getresponse() + with TestServer(handler, timeout=0.1, max_header_line_size=100./1024, max_request_body_size=100./(1024*1024)) as server: - # Test 404 conn = server.connect() + r = raw_send(conn, b'hello\n') + self.ae(r.status, httplib.BAD_REQUEST) + self.ae(r.read(), b'HTTP requires CRLF line terminators') + + r = raw_send(conn, b'\r\nGET /index.html HTTP/1.1\r\n\r\n') + self.ae(r.status, httplib.NOT_FOUND), self.ae(r.read(), b'Requested resource not found') + + r = raw_send(conn, b'\r\n\r\nGET /index.html HTTP/1.1\r\n\r\n') + self.ae(r.status, httplib.BAD_REQUEST) + self.ae(r.read(), b'Multiple leading empty lines not allowed') + + r = raw_send(conn, b'hello world\r\n') + self.ae(r.status, httplib.BAD_REQUEST) + self.ae(r.read(), b'Malformed Request-Line') + + r = raw_send(conn, b'x' * 200) + self.ae(r.status, httplib.BAD_REQUEST) + self.ae(r.read(), b'') + + r = raw_send(conn, b'XXX /index.html HTTP/1.1\r\n\r\n') + self.ae(r.status, httplib.BAD_REQUEST), self.ae(r.read(), b'Unknown HTTP method') + + # Test 404 conn.request('HEAD', '/moose') r = conn.getresponse() self.ae(r.status, httplib.NOT_FOUND) @@ -104,32 +130,48 @@ class TestHTTP(BaseTest): self.ae(r.getheader('Content-Type'), 'text/plain; charset=UTF-8') self.ae(len(r.getheaders()), 3) self.ae(r.read(), '') - conn.request('GET', '/moose') + conn.request('GET', '/choose') r = conn.getresponse() self.ae(r.status, httplib.NOT_FOUND) - self.ae(r.read(), 'Requested resource not found') + self.ae(r.read(), b'Requested resource not found') - server.change_handler(lambda conn:conn.path[0] + conn.input_reader.read().decode('ascii')) + # Test 500 + 
orig = server.loop.log.filter_level + server.loop.log.filter_level = server.loop.log.ERROR + 10 + server.change_handler(lambda data:1/0) + conn = server.connect() + conn.request('GET', '/test/') + r = conn.getresponse() + self.ae(r.status, httplib.INTERNAL_SERVER_ERROR) + server.loop.log.filter_level = orig + + server.change_handler(lambda data:data.path[0] + data.read().decode('ascii')) conn = server.connect() # Test simple GET conn.request('GET', '/test/') r = conn.getresponse() self.ae(r.status, httplib.OK) - self.ae(r.read(), 'test') + self.ae(r.read(), b'test') + + # Test TRACE + lines = ['TRACE /xxx HTTP/1.1', 'Test: value', 'Xyz: abc, def', '', ''] + r = raw_send(conn, ('\r\n'.join(lines)).encode('ascii')) + self.ae(r.status, httplib.OK) + self.ae(r.read().decode('utf-8'), '\n'.join(lines[:-2])) # Test POST with simple body conn.request('POST', '/test', 'body') r = conn.getresponse() self.ae(r.status, httplib.CREATED) - self.ae(r.read(), 'testbody') + self.ae(r.read(), b'testbody') # Test POST with chunked transfer encoding conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'}) - conn.send(b'4\r\nbody\r\n0\r\n\r\n') + conn.send(b'4\r\nbody\r\na\r\n1234567890\r\n0\r\n\r\n') r = conn.getresponse() self.ae(r.status, httplib.CREATED) - self.ae(r.read(), 'testbody') + self.ae(r.read(), b'testbody1234567890') # Test various incorrect input orig_level, server.log.filter_level = server.log.filter_level, server.log.ERROR @@ -150,19 +192,26 @@ class TestHTTP(BaseTest): self.ae(r.status, httplib.BAD_REQUEST) self.assertIn(b'not a valid chunk size', r.read()) + conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'}) + conn.send(b'4\r\nbody\r\n200\r\n\r\n') + r = conn.getresponse() + self.ae(r.status, httplib.REQUEST_ENTITY_TOO_LARGE) + conn.request('POST', '/test', body='a'*200) + r = conn.getresponse() + self.ae(r.status, httplib.REQUEST_ENTITY_TOO_LARGE) + conn = server.connect() conn.request('POST', '/test', 
headers={'Transfer-Encoding': 'chunked'}) conn.send(b'3\r\nbody\r\n0\r\n\r\n') r = conn.getresponse() - self.ae(r.status, httplib.BAD_REQUEST) - self.assertIn(b'!= CRLF', r.read()) + self.ae(r.status, httplib.BAD_REQUEST), self.ae(r.read(), b'Chunk does not have trailing CRLF') conn = server.connect(timeout=1) conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'}) conn.send(b'30\r\nbody\r\n0\r\n\r\n') r = conn.getresponse() - self.ae(r.status, httplib.BAD_REQUEST) - self.assertIn(b'Timed out waiting for chunk', r.read()) + self.ae(r.status, httplib.REQUEST_TIMEOUT) + self.assertIn(b'', r.read()) server.log.filter_level = orig_level conn = server.connect() @@ -180,18 +229,17 @@ class TestHTTP(BaseTest): # Test closing conn.request('GET', '/close', headers={'Connection':'close'}) - self.ae(server.loop.requests.busy, 1) r = conn.getresponse() + self.ae(server.loop.num_active_connections, 1) self.ae(r.status, 200), self.ae(r.read(), 'close') - self.ae(server.loop.requests.busy, 0) + server.loop.wakeup() + self.ae(server.loop.num_active_connections, 0) self.assertIsNone(conn.sock) - self.ae(server.loop.requests.idle, 10) - # }}} def test_http_response(self): # {{{ 'Test HTTP protocol responses' - from calibre.srv.respond import parse_multipart_byterange + from calibre.srv.http_response import parse_multipart_byterange def handler(conn): return conn.generate_static_output('test', lambda : ''.join(conn.path)) with TestServer(handler, timeout=0.1, compress_min_size=0) as server, \ @@ -216,9 +264,10 @@ class TestHTTP(BaseTest): r = conn.getresponse() self.ae(r.status, httplib.OK), self.ae(zlib.decompress(r.read(), 16+zlib.MAX_WBITS), b'an_etagged_path') - for i in '12': - # Test getting a filesystem file + # Test getting a filesystem file + for use_sendfile in (True, False): server.change_handler(lambda conn: f) + server.loop.opts.use_sendfile = use_sendfile conn = server.connect() conn.request('GET', '/test') r = conn.getresponse() @@ -229,27 +278,27 @@ 
class TestHTTP(BaseTest): self.ae(int(r.getheader('Content-Length')), len(fdata)) self.ae(r.status, httplib.OK), self.ae(r.read(), fdata) - conn.request('GET', '/test', headers={'Range':'bytes=0-25'}) + conn.request('GET', '/test', headers={'Range':'bytes=2-25'}) r = conn.getresponse() self.ae(type('')(r.getheader('Accept-Ranges')), 'bytes') - self.ae(type('')(r.getheader('Content-Range')), 'bytes 0-25/%d' % len(fdata)) - self.ae(int(r.getheader('Content-Length')), 26) - self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[0:26]) + self.ae(type('')(r.getheader('Content-Range')), 'bytes 2-25/%d' % len(fdata)) + self.ae(int(r.getheader('Content-Length')), 24) + self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[2:26]) conn.request('GET', '/test', headers={'Range':'bytes=100000-'}) r = conn.getresponse() self.ae(type('')(r.getheader('Content-Range')), 'bytes */%d' % len(fdata)) self.ae(r.status, httplib.REQUESTED_RANGE_NOT_SATISFIABLE) - conn.request('GET', '/test', headers={'Range':'bytes=0-1000000'}) - r = conn.getresponse() - self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata) - conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':etag}) r = conn.getresponse() self.ae(int(r.getheader('Content-Length')), 26) self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[25:51]) + conn.request('GET', '/test', headers={'Range':'bytes=0-1000000'}) + r = conn.getresponse() + self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata) + conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':'"nomatch"'}) r = conn.getresponse() self.assertFalse(r.getheader('Content-Range')) @@ -274,6 +323,4 @@ class TestHTTP(BaseTest): self.ae(data, r.read()) # Now try it without sendfile - server.loop.opts.use_sendfile ^= True - conn = server.connect() # }}} diff --git a/src/calibre/srv/utils.py b/src/calibre/srv/utils.py index 60d6ced1d1..cb6602db69 100644 --- a/src/calibre/srv/utils.py 
+++ b/src/calibre/srv/utils.py @@ -11,9 +11,13 @@ from contextlib import closing from urlparse import parse_qs import repr as reprlib from email.utils import formatdate +from operator import itemgetter from calibre.constants import iswindows +HTTP1 = 'HTTP/1.0' +HTTP11 = 'HTTP/1.1' + def http_date(timeval=None): return type('')(formatdate(timeval=timeval, usegmt=True)) @@ -85,7 +89,8 @@ class MultiDict(dict): # {{{ __str__ = __unicode__ = __repr__ def pretty(self, leading_whitespace=''): - return leading_whitespace + ('\n' + leading_whitespace).join('%s: %s' % (k, v) for k, v in self.items()) + return leading_whitespace + ('\n' + leading_whitespace).join( + '%s: %s' % (k, (repr(v) if isinstance(v, bytes) else v)) for k, v in sorted(self.items(), key=itemgetter(0))) # }}} def error_codes(*errnames): @@ -112,29 +117,15 @@ socket_errors_socket_closed = error_codes( # errors indicating a disconnected c socket_errors_nonblocking = error_codes( 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') -class Corked(object): +def start_cork(sock): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0) + if hasattr(socket, 'TCP_CORK'): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1) - ' Context manager to turn on TCP corking. Ensures maximum throughput for large logical packets. 
' - - def __init__(self, sock): - self.sock = sock - - def __enter__(self): - nodelay = self.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) - if nodelay == 1: - self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0) - self.set_nodelay = True - else: - self.set_nodelay = False - if hasattr(socket, 'TCP_CORK'): - self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1) - - def __exit__(self, *args): - if self.set_nodelay: - self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - if hasattr(socket, 'TCP_CORK'): - self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0) - self.sock.send(b'') # Ensure that uncorking occurs +def stop_cork(sock): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if hasattr(socket, 'TCP_CORK'): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0) def create_sock_pair(port=0): '''Create socket pair. Works also on windows by using an ephemeral TCP port.'''