Finish implementation of async http server

This commit is contained in:
Kovid Goyal 2015-05-25 18:34:51 +05:30
parent 2068e52b82
commit d075eff758
12 changed files with 1314 additions and 2064 deletions

View File

@ -1,337 +0,0 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import ssl, socket, select, os
from io import BytesIO
from calibre import as_unicode
from calibre.srv.opts import Options
from calibre.srv.utils import (
socket_errors_socket_closed, socket_errors_nonblocking, HandleInterrupt)
from calibre.utils.socket_inheritance import set_socket_inherit
from calibre.utils.logging import ThreadSafeLog
from calibre.utils.monotonic import monotonic
# Connection interest states: which select() condition (read, write, or
# both) a connection is currently waiting on.
READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR'
class Connection(object):

    ''' A single client connection managed by the ServerLoop. Handles
    optional TLS, non-blocking send/recv, and tracks which select() event
    the connection is waiting for. Subclasses must implement
    connection_ready(). '''

    def __init__(self, socket, opts, ssl_context):
        self.opts = opts
        self.ssl_context = ssl_context
        self.wait_for = READ
        if self.ssl_context is not None:
            # Perform the TLS handshake asynchronously, driven by the event loop
            self.ready = False
            self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False)
            self.set_state(RDWR, self.do_ssl_handshake)
        else:
            self.ready = True
            self.socket = socket
            self.connection_ready()
        self.last_activity = monotonic()

    def set_state(self, wait_for, func):
        ''' Set the select() condition to wait for and the callback to invoke
        when it fires. '''
        self.wait_for = wait_for
        self.handle_event = func

    def do_ssl_handshake(self, event):
        ''' Drive the non-blocking TLS handshake forward, one step per
        select() wakeup. '''
        try:
            # FIX: the original called self._sslobj.do_handshake(), but no
            # _sslobj attribute is ever created on this object; the wrapped
            # socket itself performs the handshake.
            self.socket.do_handshake()
        except ssl.SSLWantReadError:
            self.set_state(READ, self.do_ssl_handshake)
        except ssl.SSLWantWriteError:
            self.set_state(WRITE, self.do_ssl_handshake)
        else:
            # FIX: only mark the connection ready when the handshake has
            # actually completed. The original set ready = True
            # unconditionally, even when the handshake was still pending
            # (the want-read/want-write cases above).
            self.ready = True
            self.connection_ready()

    def send(self, data):
        ''' Non-blocking send. Returns the number of bytes written (0 when
        the socket would block or was closed by the peer). '''
        try:
            ret = self.socket.send(data)
            self.last_activity = monotonic()
            return ret
        except socket.error as e:
            if e.errno in socket_errors_nonblocking:
                return 0
            elif e.errno in socket_errors_socket_closed:
                self.ready = False
                return 0
            raise

    def recv(self, buffer_size):
        ''' Non-blocking read of at most buffer_size bytes. Returns b'' when
        no data is available or the connection was closed by the peer. '''
        try:
            data = self.socket.recv(buffer_size)
            self.last_activity = monotonic()
            if not data:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.ready = False
                return b''
            return data
        except ssl.SSLWantReadError:
            # TLS layer needs more raw bytes before it can hand over data
            return b''
        except socket.error as e:
            # FIX: mirror send() -- treat would-block as "no data", treat a
            # closed socket as EOF, and re-raise anything unexpected instead
            # of silently falling through and returning None.
            if e.errno in socket_errors_nonblocking:
                return b''
            if e.errno in socket_errors_socket_closed:
                self.ready = False
                return b''
            raise

    def close(self):
        ''' Shut down and close the underlying socket, ignoring socket
        errors (the peer may already be gone). '''
        self.ready = False
        try:
            self.socket.shutdown(socket.SHUT_WR)
            self.socket.close()
        except socket.error:
            pass

    def connection_ready(self):
        ''' Called once the connection is usable (after the TLS handshake,
        if any). Subclasses must implement this to install their initial
        read/write state. '''
        raise NotImplementedError()
class ServerLoop(object):

    ''' A non-blocking, select() based server loop. Accepts incoming
    connections and dispatches read/write events to per-connection handler
    objects created by the supplied handler factory. '''

    def __init__(
        self,
        handler,
        bind_address=('localhost', 8080),
        opts=None,
        # A calibre logging object. If None, a default log that logs to
        # stdout is used
        log=None
    ):
        self.ready = False
        self.handler = handler
        self.opts = opts or Options()
        self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG)
        ba = tuple(bind_address)
        if not ba[0]:
            # AI_PASSIVE does not work with host of '' or None
            ba = ('0.0.0.0', ba[1])
        self.bind_address = ba
        self.bound_address = None
        self.ssl_context = None
        if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None:
            self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            self.ssl_context.load_cert_chain(certfile=self.opts.ssl_certfile, keyfile=self.opts.ssl_keyfile)
        self.pre_activated_socket = None
        if self.opts.allow_socket_preallocation:
            from calibre.srv.pre_activated import pre_activated_socket
            self.pre_activated_socket = pre_activated_socket()
            if self.pre_activated_socket is not None:
                set_socket_inherit(self.pre_activated_socket, False)
                self.bind_address = self.pre_activated_socket.getsockname()

    def __str__(self):
        return "%s(%r)" % (self.__class__.__name__, self.bind_address)
    __repr__ = __str__

    def serve_forever(self):
        """ Listen for incoming connections. """
        if self.pre_activated_socket is None:
            # AF_INET or AF_INET6 socket
            # Get the correct address family for our host (allows IPv6
            # addresses)
            host, port = self.bind_address
            try:
                info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC,
                    socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
            except socket.gaierror:
                if ':' in host:
                    info = [(socket.AF_INET6, socket.SOCK_STREAM,
                             0, "", self.bind_address + (0, 0))]
                else:
                    info = [(socket.AF_INET, socket.SOCK_STREAM,
                             0, "", self.bind_address)]
            self.socket = None
            msg = "No socket could be created"
            for res in info:
                af, socktype, proto, canonname, sa = res
                try:
                    self.bind(af, socktype, proto)
                # FIX: use the "except ... as e" form; the old
                # "except socket.error, serr" comma syntax is Python 2 only
                # and inconsistent with every other handler in this module.
                except socket.error as serr:
                    msg = "%s -- (%s: %s)" % (msg, sa, serr)
                    if self.socket:
                        self.socket.close()
                    self.socket = None
                    continue
                break
            if not self.socket:
                raise socket.error(msg)
        else:
            self.socket = self.pre_activated_socket
            self.pre_activated_socket = None
            self.setup_socket()
        self.ready = True
        self.connection_map = {}
        self.socket.listen(min(socket.SOMAXCONN, 128))
        self.bound_address = ba = self.socket.getsockname()
        if isinstance(ba, tuple):
            ba = ':'.join(map(type(''), ba))
        self.log('calibre server listening on', ba)
        while True:
            try:
                self.tick()
            except (KeyboardInterrupt, SystemExit):
                self.shutdown()
                break
            except:
                self.log.exception('Error in ServerLoop.tick')

    def setup_socket(self):
        ''' Apply the standard options to the listening socket. '''
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if self.opts.no_delay and not isinstance(self.bind_address, basestring):
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
        # activate dual-stack.
        if (hasattr(socket, 'AF_INET6') and self.socket.family == socket.AF_INET6 and
                self.bind_address[0] in ('::', '::0', '::0.0.0.0')):
            try:
                self.socket.setsockopt(
                    socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            except (AttributeError, socket.error):
                # Apparently, the socket option is not available in
                # this machine's TCP stack
                pass
        self.socket.setblocking(0)

    def bind(self, family, atype, proto=0):
        '''Create (or recreate) the actual socket object.'''
        self.socket = socket.socket(family, atype, proto)
        set_socket_inherit(self.socket, False)
        self.setup_socket()
        self.socket.bind(self.bind_address)

    def tick(self):
        ''' One iteration of the event loop: expire idle connections, wait
        for socket events with select() and dispatch them. '''
        now = monotonic()
        for s, conn in tuple(self.connection_map.iteritems()):
            if now - conn.last_activity > self.opts.timeout:
                self.log.debug('Closing connection because of extended inactivity')
                self.close(s, conn)
        read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR]
        write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR]
        readable, writable = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout)[:2]
        if not self.ready:
            return
        for s, conn, event in self.get_actions(readable, writable):
            try:
                conn.handle_event(event)
                if not conn.ready:
                    self.close(s, conn)
            except Exception as e:
                if conn.ready:
                    self.log.exception('Unhandled exception, terminating connection')
                else:
                    self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e))
                self.close(s, conn)

    def wakeup(self):
        # Touch our own socket to make select() return immediately.
        sock = getattr(self, "socket", None)
        if sock is not None:
            try:
                host, port = sock.getsockname()[:2]
            except socket.error as e:
                if e.errno not in socket_errors_socket_closed:
                    raise
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                              socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    s = None
                    try:
                        s = socket.socket(af, socktype, proto)
                        s.settimeout(1.0)
                        s.connect((host, port))
                        s.close()
                    except socket.error:
                        if s is not None:
                            s.close()
        return sock

    def close(self, s, conn):
        ''' Remove the connection from the map and close it. '''
        self.connection_map.pop(s, None)
        conn.close()

    def get_actions(self, readable, writable):
        ''' Yield (socket, connection, event) triples for the sockets
        select() reported active, accepting new connections as needed. '''
        for s in readable:
            if s is self.socket:
                s, addr = self.accept()
                if s is not None:
                    self.connection_map[s] = conn = self.handler(s, self.opts, self.ssl_context)
                    if self.ssl_context is not None:
                        yield s, conn, RDWR
            else:
                yield s, self.connection_map[s], READ
        for s in writable:
            yield s, self.connection_map[s], WRITE

    def accept(self):
        ''' Accept a pending connection, returning (None, None) on error. '''
        try:
            return self.socket.accept()
        except socket.error:
            return None, None

    def stop(self):
        ''' Ask the serve_forever() loop to stop (thread-safe via wakeup). '''
        self.ready = False
        self.wakeup()

    def shutdown(self):
        ''' Close the listening socket and every client connection. '''
        try:
            if getattr(self, 'socket', None):
                self.socket.close()
                self.socket = None
        except socket.error:
            pass
        for s, conn in tuple(self.connection_map.iteritems()):
            self.close(s, conn)
class EchoLine(Connection):  # {{{

    ''' Trivial line-echo protocol used for manual testing of the server
    loop. Each received line is echoed back; an empty line is echoed with a
    "bye" prefix and the connection is closed afterwards. '''

    bye_after_echo = False

    def connection_ready(self):
        self.rbuf = BytesIO()
        self.set_state(READ, self.read_line)

    def read_line(self, event):
        byte = self.recv(1)
        if not byte:
            return
        self.rbuf.write(byte)
        if byte != b'\n':
            return
        if self.rbuf.tell() < 3:
            # Empty line (bare CRLF): prefix the echo with bye and quit after
            self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue())
            self.bye_after_echo = True
        self.set_state(WRITE, self.echo)
        self.rbuf.seek(0)

    def echo(self, event):
        start = self.rbuf.tell()
        self.rbuf.seek(0, os.SEEK_END)
        unsent = self.rbuf.tell() - start
        self.rbuf.seek(start)
        nsent = self.send(self.rbuf.read(512))
        if nsent == unsent:
            # Whole buffer flushed, go back to reading the next line
            self.rbuf = BytesIO()
            self.set_state(READ, self.read_line)
            if self.bye_after_echo:
                self.ready = False
        else:
            self.rbuf.seek(start + nsent)
# }}}
if __name__ == '__main__':
    # Manual smoke test: run an echo server. HandleInterrupt uses wakeup()
    # so that Ctrl-C breaks out of the select() wait promptly.
    s = ServerLoop(EchoLine)
    with HandleInterrupt(s.wakeup):
        s.serve_forever()

View File

@ -7,26 +7,5 @@ __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
class MaxSizeExceeded(Exception):

    ''' Raised when some part of a request (URI, header line, body)
    exceeds its configured size limit. '''

    def __init__(self, prefix, size, limit):
        message = '%s %d > maximum %d' % (prefix, size, limit)
        Exception.__init__(self, message)
        self.size, self.limit = size, limit
# FIX: the side-by-side diff rendering duplicated these two lines
# ("class HTTP404(Exception): class HTTP404(Exception):" / "pass pass");
# restored to a single valid class definition.
class HTTP404(Exception):
    ''' Raised by request handlers to generate a 404 Not Found response. '''
    pass
class IfNoneMatch(Exception):

    ''' Signals that an If-None-Match precondition was met; carries the
    matching ETag (if any). The exception message is deliberately empty. '''

    def __init__(self, etag=None):
        self.etag = etag
        Exception.__init__(self, '')
class BadChunkedInput(ValueError):
    ''' Raised when a chunked request body violates the chunked transfer
    encoding. '''
    pass
class RangeNotSatisfiable(ValueError):

    ''' Raised when a Range request cannot be satisfied; carries the total
    content length for the Content-Range response header. '''

    def __init__(self, content_length):
        ValueError.__init__(self)
        self.content_length = content_length

View File

@ -1,602 +0,0 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import httplib, socket, re, os
from io import BytesIO
import repr as reprlib
from urllib import unquote
from functools import partial
from operator import itemgetter
from calibre import as_unicode
from calibre.constants import __version__
from calibre.srv.errors import (
MaxSizeExceeded, HTTP404, IfNoneMatch, BadChunkedInput, RangeNotSatisfiable)
from calibre.srv.respond import finalize_output, generate_static_output
from calibre.srv.utils import MultiDict, http_date, socket_errors_to_ignore
# Protocol version strings used in status lines
HTTP1 = 'HTTP/1.0'
HTTP11 = 'HTTP/1.1'
# Map (major, minor) version tuples parsed from the request line to the
# corresponding protocol string
protocol_map = {(1, 0):HTTP1, (1, 1):HTTP11}
# Matches a percent-encoded forward slash, which must not be treated as a
# path separator when splitting the request path
quoted_slash = re.compile(br'%2[fF]')
def parse_request_uri(uri):  # {{{
    """Parse a Request-URI into (scheme, authority, path).
    Note that Request-URI's must be one of::
        Request-URI = "*" | absoluteURI | abs_path | authority
    Therefore, a Request-URI which starts with a double forward-slash
    cannot be a "net_path"::
        net_path = "//" authority [ abs_path ]
    Instead, it must be interpreted as an "abs_path" with an empty first
    path segment::
        abs_path = "/" path_segments
        path_segments = segment *( "/" segment )
        segment = *pchar *( ";" param )
        param = *pchar
    """
    if uri == b'*':
        # The asterisk form, used by OPTIONS
        return None, None, uri
    divider = uri.find(b'://')
    if divider > 0 and b'?' not in uri[:divider]:
        # An absoluteURI.
        # If there's a scheme (and it must be http or https), then:
        # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query
        # ]]
        scheme = uri[:divider].lower()
        rest = uri[divider + 3:]
        authority, path = rest.split(b'/', 1)
        return scheme, authority, b'/' + path
    if uri.startswith(b'/'):
        # An abs_path.
        return None, None, uri
    # An authority.
    return None, uri, None
# }}}
# Headers whose repeated occurrences may be folded into a single
# comma-separated value
comma_separated_headers = {
    'Accept', 'Accept-Charset', 'Accept-Encoding',
    'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
    'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
    'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
    'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
    'WWW-Authenticate'
}
# Headers whose values must be decodeable as ASCII (a decode failure on
# these is treated as a malformed request rather than passed through raw)
decoded_headers = {
    'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
} | comma_separated_headers
def read_headers(readline):  # {{{
    """
    Read headers from the given stream into the given header dict.
    If hdict is None, a new header dict is created. Returns the populated
    header dict.
    Headers which are repeated are folded together using a comma if their
    specification so dictates.
    This function raises ValueError when the read bytes violate the HTTP spec.
    You should probably return "400 Bad Request" if this happens.
    """
    hdict = MultiDict()

    def safe_decode(hname, value):
        # Header values are decoded as ASCII when possible; for headers the
        # server itself must inspect (decoded_headers), a decode failure is
        # a hard error, otherwise the raw bytes are kept.
        try:
            return value.decode('ascii')
        except UnicodeDecodeError:
            if hname in decoded_headers:
                raise
            return value

    # The header currently being accumulated (continuation lines may extend
    # current_value before it is committed to hdict)
    current_key = current_value = None

    def commit():
        # Flush the accumulated header (if any) into hdict, folding repeats
        # of comma-separated headers into one value
        if current_key:
            key = current_key.decode('ascii')
            val = safe_decode(key, current_value)
            if key in comma_separated_headers:
                existing = hdict.pop(key)
                if existing is not None:
                    val = existing + ', ' + val
            hdict[key] = val

    while True:
        line = readline()
        if not line:
            # No more data--illegal end of headers
            raise ValueError("Illegal end of headers.")
        if line == b'\r\n':
            # Normal end of headers
            commit()
            break
        if not line.endswith(b'\r\n'):
            raise ValueError("HTTP requires CRLF terminators")
        if line[0] in b' \t':
            # It's a continuation line.
            if current_key is None or current_value is None:
                raise ValueError('Orphaned continuation line')
            current_value += b' ' + line.strip()
        else:
            commit()
            current_key = current_value = None
            k, v = line.split(b':', 1)
            current_key = k.strip().title()
            current_value = v.strip()
    return hdict
# }}}
# }}}
def http_communicate(handle_request, conn):
    ' Represents interaction with a http client over a single, persistent connection '
    # Tracks whether at least one complete request has been parsed on this
    # connection (used to decide whether a timeout warrants a 408)
    request_seen = False

    def repr_for_pair(pair):
        # Safe logging representation; pair may be None or half-constructed
        return pair.repr_for_log() if getattr(pair, 'started_request', False) else 'None'

    def simple_response(pair, code, msg=''):
        # Best-effort error response: only if headers were not already sent,
        # and ignoring expected socket errors from a dead client
        if pair and not pair.sent_headers:
            try:
                pair.simple_response(code, msg=msg)
            except socket.error as e:
                if e.errno not in socket_errors_to_ignore:
                    raise

    try:
        while True:
            # (re)set pair to None so that if something goes wrong in
            # the HTTPPair constructor, the error doesn't
            # get written to the previous request.
            pair = None
            pair = HTTPPair(handle_request, conn)
            # This order of operations should guarantee correct pipelining.
            pair.parse_request()
            if not pair.ready:
                # Something went wrong in the parsing (and the server has
                # probably already made a simple_response). Return and
                # let the conn close.
                return
            request_seen = True
            pair.respond()
            if pair.close_connection:
                return
    except socket.timeout:
        # Don't error if we're between requests; only error
        # if 1) no request has been started at all, or 2) we're
        # in the middle of a request. This allows persistent
        # connections for HTTP/1.1
        if (not request_seen) or (pair and pair.started_request):
            # Don't bother writing the 408 if the response
            # has already started being written.
            simple_response(pair, httplib.REQUEST_TIMEOUT)
    except socket.error:
        # This socket is broken. Log the error and close connection
        conn.server_loop.log.exception(
            'Communication failed (socket error) while processing request:', repr_for_pair(pair))
    except MaxSizeExceeded as e:
        conn.server_loop.log.warn('Too large request body (%d > %d) for request:' % (e.size, e.limit), repr_for_pair(pair))
        # Can happen if the request uses chunked transfer encoding
        simple_response(pair, httplib.REQUEST_ENTITY_TOO_LARGE,
            "The entity sent with the request exceeds the maximum "
            "allowed bytes (%d)." % pair.max_request_body_size)
    except BadChunkedInput as e:
        conn.server_loop.log.warn('Bad chunked encoding (%s) for request:' % as_unicode(e.message), repr_for_pair(pair))
        simple_response(pair, httplib.BAD_REQUEST,
            'Invalid chunked encoding for request body: %s' % as_unicode(e.message))
    except Exception:
        conn.server_loop.log.exception('Error serving request:', pair.repr_for_log() if getattr(pair, 'started_request', False) else 'None')
        simple_response(pair, httplib.INTERNAL_SERVER_ERROR)
class FixedSizeReader(object):  # {{{

    ''' Read at most content_length bytes from socket_file, presenting a
    file-like read() interface for a fixed-size request body. '''

    def __init__(self, socket_file, content_length):
        self.socket_file = socket_file
        self.remaining = content_length

    def read(self, size=-1):
        ''' Read up to size bytes (all remaining bytes when size < 0),
        never going past the declared content length. '''
        limit = self.remaining if size < 0 else min(size, self.remaining)
        if limit < 1:
            return b''
        data = self.socket_file.read(limit)
        self.remaining -= len(data)
        return data
# }}}
class ChunkedReader(object):  # {{{

    ''' Read a request body sent with chunked transfer encoding, enforcing
    the maximum allowed request body size. '''

    def __init__(self, socket_file, maxsize):
        self.socket_file, self.maxsize = socket_file, maxsize
        self.rbuf = BytesIO()
        self.bytes_read = 0
        self.finished = False

    def check_size(self):
        # Enforce the configured request body size limit
        if self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize)

    def read_chunk(self):
        ''' Read one chunk (size line, payload and trailing CRLF) into the
        internal buffer. Sets finished on the zero-length final chunk. '''
        if self.finished:
            return
        line = self.socket_file.readline()
        self.bytes_read += len(line)
        self.check_size()
        # Strip any chunk extensions ("size;name=value") before parsing
        chunk_size = line.strip().split(b';', 1)[0]
        try:
            # FIX: parse the extension-stripped token. The original parsed
            # the raw line with int(line, 16), which raised for any valid
            # chunk that carried extensions (e.g. b"5;name=val\r\n").
            # +2 accounts for the CRLF that terminates the chunk payload.
            chunk_size = int(chunk_size, 16) + 2
        except Exception:
            raise BadChunkedInput('%s is not a valid chunk size' % reprlib.repr(chunk_size))
        if chunk_size + self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize)
        try:
            chunk = self.socket_file.read(chunk_size)
        except socket.timeout:
            raise BadChunkedInput('Timed out waiting for chunk of size %d to complete' % chunk_size)
        if len(chunk) < chunk_size:
            raise BadChunkedInput('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size))
        if not chunk.endswith(b'\r\n'):
            raise BadChunkedInput('Bad chunked encoding: %r != CRLF' % chunk[-2:])
        self.rbuf.seek(0, os.SEEK_END)
        self.bytes_read += chunk_size
        if chunk_size == 2:
            # A zero-length chunk (just the CRLF) marks the end of the body
            self.finished = True
        else:
            self.rbuf.write(chunk[:-2])

    def read(self, size=-1):
        ''' Read up to size bytes of de-chunked body data (everything when
        size < 0). '''
        if size < 0:
            # Read all data
            while not self.finished:
                self.read_chunk()
            self.rbuf.seek(0)
            rv = self.rbuf.read()
            if rv:
                self.rbuf.truncate(0)
            return rv
        if size == 0:
            return b''
        while self.rbuf.tell() < size and not self.finished:
            self.read_chunk()
        data = self.rbuf.getvalue()
        self.rbuf.truncate(0)
        if size < len(data):
            # Keep the surplus buffered for the next read()
            self.rbuf.write(data[size:])
            return data[:size]
        return data
# }}}
class HTTPPair(object):

    ''' Represents a HTTP request/response pair '''

    def __init__(self, handle_request, conn):
        self.conn = conn
        self.server_loop = conn.server_loop
        # Configured limits: header line size (KB) and body size (MB) in bytes
        self.max_header_line_size = int(self.server_loop.opts.max_header_line_size * 1024)
        self.max_request_body_size = int(self.server_loop.opts.max_request_body_size * 1024 * 1024)
        self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
        self.inheaders = MultiDict()
        self.outheaders = MultiDict()
        self.handle_request = handle_request
        self.request_line = None
        self.path = ()
        self.qs = MultiDict()
        self.method = None

        """When True, the request has been parsed and is ready to begin generating
        the response. When False, signals the calling Connection that the response
        should not be generated and the connection should close, immediately after
        parsing the request."""
        self.ready = False

        """Signals the calling Connection that the request should close. This does
        not imply an error! The client and/or server may each request that the
        connection be closed, after the response."""
        self.close_connection = False

        self.started_request = False
        self.response_protocol = HTTP1
        self.status_code = None
        self.sent_headers = False
        self.request_content_length = 0
        self.chunked_read = False

    def parse_request(self):
        """Parse the next HTTP request start-line and message-headers."""
        try:
            if not self.read_request_line():
                return
        except MaxSizeExceeded as e:
            self.server_loop.log.warn('Too large request URI (%d > %d), dropping connection' % (e.size, e.limit))
            self.simple_response(
                httplib.REQUEST_URI_TOO_LONG,
                "The Request-URI sent with the request exceeds the maximum allowed bytes.")
            return
        try:
            if not self.read_request_headers():
                return
        except MaxSizeExceeded as e:
            self.server_loop.log.warn('Too large header (%d > %d) for request, dropping connection' % (e.size, e.limit))
            self.simple_response(
                httplib.REQUEST_ENTITY_TOO_LARGE,
                "The headers sent with the request exceed the maximum allowed bytes.")
            return
        self.ready = True

    def read_request_line(self):
        ''' Read and parse the Request-Line, returning False (after sending
        an error response) when it is malformed. '''
        self.request_line = request_line = self.conn.socket_file.readline(maxsize=self.max_header_line_size)
        # Set started_request to True so http_communicate() knows to send 408
        # from here on out.
        self.started_request = True
        if not request_line:
            return False
        if request_line == b'\r\n':
            # RFC 2616 sec 4.1: "...if the server is reading the protocol
            # stream at the beginning of a message and receives a CRLF
            # first, it should ignore the CRLF."
            # But only ignore one leading line! else we enable a DoS.
            request_line = self.conn.socket_file.readline(maxsize=self.max_header_line_size)
            if not request_line:
                return False
        if not request_line.endswith(b'\r\n'):
            self.simple_response(
                httplib.BAD_REQUEST, "HTTP requires CRLF terminators")
            return False
        try:
            method, uri, req_protocol = request_line.strip().split(b' ', 2)
            rp = int(req_protocol[5]), int(req_protocol[7])
            self.method = method.decode('ascii')
        except (ValueError, IndexError):
            self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line")
            return False
        try:
            self.request_protocol = protocol_map[rp]
        except KeyError:
            self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED)
            return False
        # Respond with at most the protocol version the client speaks
        self.response_protocol = protocol_map[min((1, 1), rp)]
        scheme, authority, path = parse_request_uri(uri)
        if b'#' in path:
            self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.")
            return False
        if scheme:
            try:
                self.scheme = scheme.decode('ascii')
            except ValueError:
                self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme')
                return False
        qs = b''
        if b'?' in path:
            path, qs = path.split(b'?', 1)
            try:
                self.qs = MultiDict.create_from_query_string(qs)
            except Exception:
                # FIX: the original passed 'Unparseable query string' as the
                # third positional argument, which is read_remaining_input;
                # being truthy it dereferenced self.input_reader before it
                # exists (set only later, in respond()), raising
                # AttributeError instead of sending a 400.
                self.simple_response(httplib.BAD_REQUEST, 'Unparseable query string')
                return False
        try:
            # Split on encoded slashes so that %2F inside a segment survives
            # unquoting, then restore them below
            path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path))
        except ValueError as e:
            self.simple_response(httplib.BAD_REQUEST, as_unicode(e))
            return False
        self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/'))))
        return True

    def read_request_headers(self):
        ''' Read the request headers, returning False (after sending an
        error response) when they are invalid. '''
        # then all the http headers
        try:
            self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size))
            self.request_content_length = int(self.inheaders.get('Content-Length', 0))
        except ValueError as e:
            self.simple_response(httplib.BAD_REQUEST, as_unicode(e))
            return False
        if self.request_content_length > self.max_request_body_size:
            self.simple_response(
                httplib.REQUEST_ENTITY_TOO_LARGE,
                "The entity sent with the request exceeds the maximum "
                "allowed bytes (%d)." % self.max_request_body_size)
            return False
        # Persistent connection support
        if self.response_protocol is HTTP11:
            # Both server and client are HTTP/1.1
            if self.inheaders.get("Connection", "") == "close":
                self.close_connection = True
        else:
            # Either the server or client (or both) are HTTP/1.0
            if self.inheaders.get("Connection", "") != "Keep-Alive":
                self.close_connection = True
        # Transfer-Encoding support
        te = ()
        if self.response_protocol is HTTP11:
            rte = self.inheaders.get("Transfer-Encoding")
            if rte:
                te = [x.strip().lower() for x in rte.split(",") if x.strip()]
        self.chunked_read = False
        if te:
            for enc in te:
                if enc == "chunked":
                    self.chunked_read = True
                else:
                    # Note that, even if we see "chunked", we must reject
                    # if there is an extension we don't recognize.
                    self.simple_response(httplib.NOT_IMPLEMENTED, "Unknown transfer encoding: %r" % enc)
                    self.close_connection = True
                    return False
        if self.inheaders.get("Expect", '').lower() == "100-continue":
            # Don't use simple_response here, because it emits headers
            # we don't want.
            msg = HTTP11 + " 100 Continue\r\n\r\n"
            self.flushed_write(msg.encode('ascii'))
        return True

    def simple_response(self, status_code, msg="", read_remaining_input=False):
        ''' Send a minimal plain-text response with the given status code. '''
        abort = status_code in (httplib.REQUEST_ENTITY_TOO_LARGE, httplib.REQUEST_URI_TOO_LONG)
        if abort:
            self.close_connection = True
            if self.response_protocol is HTTP1:
                # HTTP/1.0 has no 413/414 codes
                status_code = httplib.BAD_REQUEST
        msg = msg.encode('utf-8')
        buf = [
            '%s %d %s' % (self.response_protocol, status_code, httplib.responses[status_code]),
            "Content-Length: %s" % len(msg),
            "Content-Type: text/plain; charset=UTF-8",
            "Date: " + http_date(),
        ]
        if abort and self.response_protocol is HTTP11:
            buf.append("Connection: close")
        buf.append('')
        buf = [(x + '\r\n').encode('ascii') for x in buf]
        if self.method != 'HEAD':
            buf.append(msg)
        if read_remaining_input:
            self.input_reader.read()
        self.flushed_write(b''.join(buf))

    def send_not_modified(self, etag=None):
        ''' Send a 304 Not Modified response, with the ETag if supplied. '''
        buf = [
            '%s %d %s' % (self.response_protocol, httplib.NOT_MODIFIED, httplib.responses[httplib.NOT_MODIFIED]),
            "Content-Length: 0",
            "Date: " + http_date(),
        ]
        if etag is not None:
            buf.append('ETag: ' + etag)
        self.send_buf(buf)

    def send_buf(self, buf, include_cache_headers=True):
        ''' Serialize and send a list of header lines, optionally copying
        cache-related headers from outheaders. '''
        if include_cache_headers:
            for header in ('Expires', 'Cache-Control', 'Vary'):
                val = self.outheaders.get(header)
                if val:
                    buf.append(header + ': ' + val)
        buf.append('')
        buf = [(x + '\r\n').encode('ascii') for x in buf]
        self.flushed_write(b''.join(buf))

    def send_range_not_satisfiable(self, content_length):
        ''' Send a 416 response advertising the actual resource size. '''
        buf = [
            '%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]),
            "Date: " + http_date(),
            "Content-Range: bytes */%d" % content_length,
        ]
        self.send_buf(buf)

    def flushed_write(self, data):
        ''' Write data to the connection and flush immediately. '''
        self.conn.socket_file.write(data)
        self.conn.socket_file.flush()

    def repr_for_log(self):
        ''' A multi-line description of this request, for logging. '''
        ans = ['HTTPPair: %r' % self.request_line]
        if self.path:
            ans.append('Path: %r' % (self.path,))
        if self.qs:
            ans.append('Query: %r' % self.qs)
        if self.inheaders:
            ans.extend(('In Headers:', self.inheaders.pretty('\t')))
        if self.outheaders:
            ans.extend(('Out Headers:', self.outheaders.pretty('\t')))
        return '\n'.join(ans)

    def generate_static_output(self, name, generator):
        ''' Generate (and cache, server-wide) a named piece of static output. '''
        return generate_static_output(self.server_loop.gso_cache, self.server_loop.gso_lock, name, generator)

    def respond(self):
        ''' Run the request handler and write out the complete response. '''
        if self.chunked_read:
            self.input_reader = ChunkedReader(self.conn.socket_file, self.max_request_body_size)
        else:
            self.input_reader = FixedSizeReader(self.conn.socket_file, self.request_content_length)
        try:
            output = self.handle_request(self)
        except HTTP404 as e:
            self.simple_response(httplib.NOT_FOUND, e.message, read_remaining_input=True)
            return
        # Read and discard any remaining body from the HTTP request
        self.input_reader.read()
        if self.status_code is None:
            self.status_code = httplib.CREATED if self.method == 'POST' else httplib.OK
        try:
            self.status_code, output = finalize_output(
                output, self.inheaders, self.outheaders, self.status_code,
                self.response_protocol is HTTP1, self.method, self.server_loop.opts)
        except IfNoneMatch as e:
            if self.method in ('GET', 'HEAD'):
                self.send_not_modified(e.etag)
            else:
                self.simple_response(httplib.PRECONDITION_FAILED)
            return
        except RangeNotSatisfiable as e:
            self.send_range_not_satisfiable(e.content_length)
            return
        with self.conn.corked:
            self.send_headers()
            if self.method != 'HEAD':
                output.commit(self.conn.socket_file)
            self.conn.socket_file.flush()

    def send_headers(self):
        ''' Write the status line and all response headers. '''
        self.sent_headers = True
        self.outheaders.set('Date', http_date(), replace_all=True)
        self.outheaders.set('Server', 'calibre %s' % __version__, replace_all=True)
        keep_alive = not self.close_connection and self.server_loop.opts.timeout > 0
        if keep_alive:
            self.outheaders.set('Keep-Alive', 'timeout=%d' % self.server_loop.opts.timeout)
        if 'Connection' not in self.outheaders:
            if self.response_protocol is HTTP11:
                if self.close_connection:
                    self.outheaders.set('Connection', 'close')
            else:
                if not self.close_connection:
                    self.outheaders.set('Connection', 'Keep-Alive')
        ct = self.outheaders.get('Content-Type', '')
        if ct.startswith('text/') and 'charset=' not in ct:
            self.outheaders.set('Content-Type', ct + '; charset=UTF-8')
        # NOTE(review): the status line always advertises HTTP/1.1 even when
        # response_protocol is HTTP/1.0 -- confirm this is intended
        buf = [HTTP11 + (' %d ' % self.status_code) + httplib.responses[self.status_code]]
        for header, value in sorted(self.outheaders.iteritems(), key=itemgetter(0)):
            buf.append('%s: %s' % (header, value))
        buf.append('')
        self.conn.socket_file.write(b''.join((x + '\r\n').encode('ascii') for x in buf))
def create_http_handler(handle_request):
    ''' Return a connection handler that speaks HTTP and dispatches parsed
    requests to handle_request. '''
    def http_handler(*args, **kwargs):
        return http_communicate(handle_request, *args, **kwargs)
    return http_handler

View File

@ -0,0 +1,374 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import re, httplib, repr as reprlib
from io import BytesIO, DEFAULT_BUFFER_SIZE
from urllib import unquote
from calibre import as_unicode, force_unicode
from calibre.ptempfile import SpooledTemporaryFile
from calibre.srv.loop import Connection, READ, WRITE
from calibre.srv.utils import MultiDict, HTTP1, HTTP11
# Map (major, minor) HTTP version tuples parsed from the request line to
# the corresponding protocol string
protocol_map = {(1, 0):HTTP1, (1, 1):HTTP11}
# Matches a percent-encoded forward slash, which must not be treated as a
# path separator when splitting the request path
quoted_slash = re.compile(br'%2[fF]')
# The set of request methods this server recognizes
HTTP_METHODS = {'HEAD', 'GET', 'PUT', 'POST', 'TRACE', 'DELETE', 'OPTIONS'}
def parse_request_uri(uri):  # {{{
    """Parse a Request-URI into (scheme, authority, path).
    Note that Request-URI's must be one of::
        Request-URI = "*" | absoluteURI | abs_path | authority
    Therefore, a Request-URI which starts with a double forward-slash
    cannot be a "net_path"::
        net_path = "//" authority [ abs_path ]
    Instead, it must be interpreted as an "abs_path" with an empty first
    path segment::
        abs_path = "/" path_segments
        path_segments = segment *( "/" segment )
        segment = *pchar *( ";" param )
        param = *pchar
    """
    if uri == b'*':
        # Asterisk form (used by OPTIONS)
        return None, None, uri
    pos = uri.find(b'://')
    if pos > 0 and b'?' not in uri[:pos]:
        # An absoluteURI.
        # If there's a scheme (and it must be http or https), then:
        # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query
        # ]]
        scheme, rest = uri[:pos].lower(), uri[pos + 3:]
        authority, path = rest.split(b'/', 1)
        return scheme, authority, b'/' + path
    if uri.startswith(b'/'):
        # An abs_path.
        return None, None, uri
    # An authority.
    return None, uri, None
# }}}
# HTTP Header parsing {{{
# Headers whose repeated occurrences may be folded into a single
# comma-separated value
comma_separated_headers = {
    'Accept', 'Accept-Charset', 'Accept-Encoding',
    'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
    'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
    'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
    'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
    'WWW-Authenticate'
}
# Headers whose values must be decodeable as ASCII (a decode failure on
# these is a hard error rather than passed through as raw bytes)
decoded_headers = {
    'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
} | comma_separated_headers
class HTTPHeaderParser(object):

    '''
    Parse HTTP headers. Use this class by repeatedly calling the created object
    with a single line at a time and checking the finished attribute. Can raise ValueError
    for malformed headers, in which case you should probably return BAD_REQUEST.
    Headers which are repeated are folded together using a comma if their
    specification so dictates.
    '''

    __slots__ = ('hdict', 'lines', 'finished')

    def __init__(self):
        self.hdict = MultiDict()
        self.lines = []
        self.finished = False

    def push(self, *lines):
        # Convenience: feed several lines in order
        for line in lines:
            self(line)

    @staticmethod
    def _safe_decode(hname, value):
        # Decode as ASCII; for headers the server itself must inspect,
        # a decode failure is propagated, otherwise raw bytes are kept
        try:
            return value.decode('ascii')
        except UnicodeDecodeError:
            if hname in decoded_headers:
                raise
            return value

    def _commit(self):
        # Flush the accumulated (possibly continued) header line into hdict
        if not self.lines:
            return
        joined = b' '.join(self.lines)
        del self.lines[:]
        k, v = joined.partition(b':')[::2]
        key = k.strip().decode('ascii').title()
        val = self._safe_decode(key, v.strip())
        if not key or not val:
            raise ValueError('Malformed header line: %s' % reprlib.repr(joined))
        if key in comma_separated_headers:
            existing = self.hdict.pop(key)
            if existing is not None:
                val = existing + ', ' + val
        self.hdict[key] = val

    def __call__(self, line):
        'Process a single line'
        if line == b'\r\n':
            # Normal end of headers
            self._commit()
            self.finished = True
            return
        if line[0] in b' \t':
            # It's a continuation line.
            if not self.lines:
                raise ValueError('Orphaned continuation line')
            self.lines.append(line.lstrip())
        else:
            self._commit()
            self.lines.append(line)
def read_headers(readline):
    ''' Read and parse HTTP headers by calling readline() until the blank
    terminator line is seen, returning the populated header dictionary. '''
    parser = HTTPHeaderParser()
    while True:
        parser(readline())
        if parser.finished:
            return parser.hdict
# }}}
class HTTPRequest(Connection):

    '''
    Asynchronous HTTP request parser. Each parsing step is a state handler
    installed via set_state(); handlers consume non-blocking socket events
    until they have a full line/body, then transition to the next state.
    Sub-classes must implement write(), simple_response() and
    prepare_response().
    '''

    # Set externally (see create_http_handler) to the callable that produces responses
    request_handler = None
    # Cache for generated static output, set externally
    static_cache = None

    def __init__(self, *args, **kwargs):
        Connection.__init__(self, *args, **kwargs)
        self.corked = False  # True while output corking is active on the socket
        # Options are specified in KB / MB respectively, convert to bytes
        self.max_header_line_size = int(1024 * self.opts.max_header_line_size)
        self.max_request_body_size = int(1024 * 1024 * self.opts.max_request_body_size)

    def read(self, buf, endpos):
        '''Read available data into buf until buf reaches endpos. Returns True
        when the read is complete, False if more socket events are needed.'''
        size = endpos - buf.tell()
        if size > 0:
            data = self.recv(min(size, DEFAULT_BUFFER_SIZE))
            if data:
                buf.write(data)
                return len(data) >= size
            else:
                return False
        else:
            return True

    def readline(self, buf):
        '''Accumulate a single CRLF-terminated line into buf, one byte per
        event. Returns the complete line (buf is reset), or None if the line
        is not complete yet or an error response was sent.'''
        if buf.tell() >= self.max_header_line_size - 1:
            # Line exceeds the configured limit; error code depends on
            # whether we are reading the request line or a header line
            self.simple_response(self.header_line_too_long_error_code)
            return
        data = self.recv(1)
        if data:
            buf.write(data)
            if b'\n' == data:
                line = buf.getvalue()
                buf.seek(0), buf.truncate()
                if line.endswith(b'\r\n'):
                    return line
                else:
                    self.simple_response(httplib.BAD_REQUEST, 'HTTP requires CRLF line terminators')

    def connection_ready(self):
        'Become ready to read an HTTP request'
        self.method = self.request_line = None
        self.response_protocol = self.request_protocol = HTTP1
        self.path = self.query = None
        self.close_after_response = False
        # Until the request line is parsed, an over-long line means the URI is too long
        self.header_line_too_long_error_code = httplib.REQUEST_URI_TOO_LONG
        self.response_started = False
        self.set_state(READ, self.parse_request_line, BytesIO(), first=True)

    def parse_request_line(self, buf, event, first=False):  # {{{
        'State: parse the Request-Line (e.g. GET /path HTTP/1.1)'
        line = self.readline(buf)
        if line is None:
            return
        if line == b'\r\n':
            # Ignore a single leading empty line, as per RFC 2616 sec 4.1
            if first:
                return self.set_state(READ, self.parse_request_line, BytesIO())
            return self.simple_response(httplib.BAD_REQUEST, 'Multiple leading empty lines not allowed')

        try:
            method, uri, req_protocol = line.strip().split(b' ', 2)
            # req_protocol looks like HTTP/x.y, pull out the version digits
            rp = int(req_protocol[5]), int(req_protocol[7])
            self.method = method.decode('ascii').upper()
        except Exception:
            return self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line")

        if self.method not in HTTP_METHODS:
            return self.simple_response(httplib.BAD_REQUEST, "Unknown HTTP method")

        try:
            self.request_protocol = protocol_map[rp]
        except KeyError:
            return self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED)
        # Respond with at most HTTP/1.1, and no more than the client supports
        self.response_protocol = protocol_map[min((1, 1), rp)]
        scheme, authority, path = parse_request_uri(uri)
        if b'#' in path:
            return self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.")

        if scheme:
            try:
                self.scheme = scheme.decode('ascii')
            except ValueError:
                return self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme')

        qs = b''
        if b'?' in path:
            path, qs = path.split(b'?', 1)
            try:
                self.query = MultiDict.create_from_query_string(qs)
            except Exception:
                return self.simple_response(httplib.BAD_REQUEST, 'Unparseable query string')

        try:
            # Unquote everything except %2F (encoded slash), which must not be
            # confused with a path separator
            path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path))
        except ValueError as e:
            return self.simple_response(httplib.BAD_REQUEST, as_unicode(e))

        self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/'))))
        # From now on an over-long line is a header line, so change the error code
        self.header_line_too_long_error_code = httplib.REQUEST_ENTITY_TOO_LARGE
        self.request_line = line.rstrip()
        self.set_state(READ, self.parse_header_line, HTTPHeaderParser(), BytesIO())
    # }}}

    @property
    def state_description(self):
        return 'Request: %s' % force_unicode(self.request_line, 'utf-8')

    def parse_header_line(self, parser, buf, event):
        'State: feed header lines to the parser until the header block ends'
        line = self.readline(buf)
        if line is None:
            return
        try:
            parser(line)
        except ValueError:
            self.simple_response(httplib.BAD_REQUEST, 'Failed to parse header line')
            return
        if parser.finished:
            self.finalize_headers(parser.hdict)

    def finalize_headers(self, inheaders):
        'Process the received headers, deciding how to read the request body'
        request_content_length = int(inheaders.get('Content-Length', 0))
        if request_content_length > self.max_request_body_size:
            return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE,
                "The entity sent with the request exceeds the maximum "
                "allowed bytes (%d)." % self.max_request_body_size)
        # Persistent connection support
        if self.response_protocol is HTTP11:
            # Both server and client are HTTP/1.1
            if inheaders.get("Connection", "") == "close":
                self.close_after_response = True
        else:
            # Either the server or client (or both) are HTTP/1.0
            if inheaders.get("Connection", "") != "Keep-Alive":
                self.close_after_response = True

        # Transfer-Encoding support
        te = ()
        if self.response_protocol is HTTP11:
            rte = inheaders.get("Transfer-Encoding")
            if rte:
                te = [x.strip().lower() for x in rte.split(",") if x.strip()]
        chunked_read = False
        if te:
            for enc in te:
                if enc == "chunked":
                    chunked_read = True
                else:
                    # Note that, even if we see "chunked", we must reject
                    # if there is an extension we don't recognize.
                    return self.simple_response(httplib.NOT_IMPLEMENTED, "Unknown transfer encoding: %r" % enc)

        if inheaders.get("Expect", '').lower() == "100-continue":
            # Client wants permission before sending the body
            buf = BytesIO((HTTP11 + " 100 Continue\r\n\r\n").encode('ascii'))
            return self.set_state(WRITE, self.write_continue, buf, inheaders, request_content_length, chunked_read)

        self.read_request_body(inheaders, request_content_length, chunked_read)

    def write_continue(self, buf, inheaders, request_content_length, chunked_read, event):
        'State: write the 100 Continue response, then start reading the body'
        if self.write(buf):
            self.read_request_body(inheaders, request_content_length, chunked_read)

    def read_request_body(self, inheaders, request_content_length, chunked_read):
        'Spool the request body to a temporary file, chunked or sized'
        buf = SpooledTemporaryFile(prefix='rq-body-', max_size=DEFAULT_BUFFER_SIZE, dir=self.tdir)
        if chunked_read:
            # bytes_read is a single element list so state handlers can mutate it
            self.set_state(READ, self.read_chunk_length, inheaders, BytesIO(), buf, [0])
        else:
            if request_content_length > 0:
                self.set_state(READ, self.sized_read, inheaders, buf, request_content_length)
            else:
                self.prepare_response(inheaders, BytesIO())

    def sized_read(self, inheaders, buf, request_content_length, event):
        'State: read a body of known length'
        if self.read(buf, request_content_length):
            self.prepare_response(inheaders, buf)

    def read_chunk_length(self, inheaders, line_buf, buf, bytes_read, event):
        'State: read the hexadecimal chunk-size line of a chunked body'
        line = self.readline(line_buf)
        if line is None:
            return
        bytes_read[0] += len(line)
        try:
            chunk_size = int(line.strip(), 16)
        except Exception:
            return self.simple_response(httplib.BAD_REQUEST, '%s is not a valid chunk size' % reprlib.repr(line.strip()))
        if bytes_read[0] + chunk_size + 2 > self.max_request_body_size:
            return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE,
                            'Chunked request is larger than %d bytes' % self.max_request_body_size)
        if chunk_size == 0:
            # Zero sized chunk marks the end of the chunked body
            self.set_state(READ, self.read_chunk_separator, inheaders, BytesIO(), buf, bytes_read, last=True)
        else:
            self.set_state(READ, self.read_chunk, inheaders, buf, chunk_size, buf.tell() + chunk_size, bytes_read)

    def read_chunk(self, inheaders, buf, chunk_size, end, bytes_read, event):
        'State: read the data portion of one chunk'
        if not self.read(buf, end):
            return
        bytes_read[0] += chunk_size
        self.set_state(READ, self.read_chunk_separator, inheaders, BytesIO(), buf, bytes_read)

    def read_chunk_separator(self, inheaders, line_buf, buf, bytes_read, event, last=False):
        'State: read the CRLF that terminates a chunk (or the final chunk)'
        line = self.readline(line_buf)
        if line is None:
            return
        if line != b'\r\n':
            return self.simple_response(httplib.BAD_REQUEST, 'Chunk does not have trailing CRLF')
        bytes_read[0] += len(line)
        if bytes_read[0] > self.max_request_body_size:
            return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE,
                            'Chunked request is larger than %d bytes' % self.max_request_body_size)
        if last:
            self.prepare_response(inheaders, buf)
        else:
            self.set_state(READ, self.read_chunk_length, inheaders, BytesIO(), buf, bytes_read)

    def handle_timeout(self):
        'Send a 408 if no response has been started yet; returns True if handled'
        if self.response_started:
            return False
        self.simple_response(httplib.REQUEST_TIMEOUT)
        return True

    def write(self, buf, end=None):
        # Implemented by the response-generating subclass
        raise NotImplementedError()

    def simple_response(self, status_code, msg='', close_after_response=True):
        # Implemented by the response-generating subclass
        raise NotImplementedError()

    def prepare_response(self, inheaders, request_body_file):
        # Implemented by the response-generating subclass
        raise NotImplementedError()

View File

@ -0,0 +1,525 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import os, httplib, hashlib, uuid, zlib, time, struct, repr as reprlib
from select import PIPE_BUF
from collections import namedtuple
from io import BytesIO, DEFAULT_BUFFER_SIZE
from itertools import chain, repeat, izip_longest
from operator import itemgetter
from functools import wraps
from calibre import guess_type, force_unicode
from calibre.constants import __version__
from calibre.srv.loop import WRITE
from calibre.srv.errors import HTTP404
from calibre.srv.http_request import HTTPRequest, read_headers
from calibre.srv.sendfile import file_metadata, sendfile_to_socket_async, CannotSendfile, SendfileInterrupted
from calibre.srv.utils import MultiDict, start_cork, stop_cork, http_date, HTTP1, HTTP11, socket_errors_socket_closed
Range = namedtuple('Range', 'start stop size')
MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii')
def header_list_to_file(buf):  # {{{
    ' Serialize a list of header lines into a file-like object, adding the terminating blank line '
    buf.append('')
    encoded = [(line + '\r\n').encode('ascii') for line in buf]
    return BytesIO(b''.join(encoded))
# }}}
def parse_multipart_byterange(buf, content_type):  # {{{
    ''' Parse a multipart/byteranges message body from the file-like object
    buf, with the boundary taken from content_type. Returns a list of
    (start, data) tuples, one per sub-part. Raises ValueError for malformed
    messages. '''
    sep = (content_type.rsplit('=', 1)[-1]).encode('utf-8')
    ans = []

    def parse_part():
        # Returns (start, data) for one sub-part, or None at the final boundary
        line = buf.readline()
        if not line:
            raise ValueError('Premature end of message')
        if not line.startswith(b'--' + sep):
            raise ValueError('Malformed start of multipart message: %s' % reprlib.repr(line))
        # The final boundary is b'--sep--', possibly followed by a CRLF.
        # readline() keeps the line terminator, so strip it before testing,
        # otherwise a terminator line with a trailing CRLF is not recognized.
        if line.rstrip(b'\r\n').endswith(b'--'):
            return None
        headers = read_headers(buf.readline)
        cr = headers.get('Content-Range')
        if not cr:
            raise ValueError('Missing Content-Range header in sub-part')
        if not cr.startswith('bytes '):
            raise ValueError('Malformed Content-Range header in sub-part, no prefix')
        try:
            # Content-Range: bytes start-stop/total
            start, stop = map(lambda x: int(x.strip()), cr.partition(' ')[-1].partition('/')[0].partition('-')[::2])
        except Exception:
            raise ValueError('Malformed Content-Range header in sub-part, failed to parse byte range')
        content_length = stop - start + 1
        ret = buf.read(content_length)
        if len(ret) != content_length:
            raise ValueError('Malformed sub-part, length of body not equal to length specified in Content-Range')
        buf.readline()  # consume the CRLF that follows the part body
        return (start, ret)
    while True:
        data = parse_part()
        if data is None:
            break
        ans.append(data)
    return ans
# }}}
def parse_if_none_match(val):  # {{{
    ' Return the set of (whitespace stripped) entity tags in an If-None-Match header value '
    return set(tag.strip() for tag in val.split(','))
# }}}
def acceptable_encoding(val, allowed=frozenset({'gzip'})):  # {{{
    ''' Return the highest-quality encoding from the Accept-Encoding header
    value val that is also in allowed, or None if there is no overlap '''

    def parse_entry(entry):
        # Each entry looks like: encoding[;q=quality]
        name, _, params = entry.partition(';')
        key, _, qval = params.partition('=')
        quality = 1.0
        if key == 'q' and qval:
            try:
                quality = float(qval)
            except Exception:
                pass  # invalid quality value, treat as 1.0
        return name.lower(), quality

    qualities = {}
    for entry in val.split(','):
        name, quality = parse_entry(entry.strip())
        qualities[name] = quality
    candidates = sorted(set(qualities) & allowed, key=qualities.__getitem__, reverse=True)
    return candidates[0] if candidates else None
# }}}
def get_ranges(headervalue, content_length):  # {{{
    ''' Return a list of ranges from the Range header. If this function returns
    an empty list, it indicates no valid range was found. Returns None when
    there is no Range header or its unit is not bytes. '''
    if not headervalue:
        return None

    result = []
    try:
        bytesunit, byteranges = headervalue.split("=", 1)
    except Exception:
        return None
    if bytesunit.strip() != 'bytes':
        return None

    for brange in byteranges.split(","):
        try:
            start, stop = [x.strip() for x in brange.split("-", 1)]
        except ValueError:
            # A byte-range-spec without a hyphen is malformed; skip it rather
            # than raising an uncaught ValueError (which would kill the
            # connection), consistent with how unparseable numbers are handled
            continue
        if start:
            if not stop:
                stop = content_length - 1
            try:
                start, stop = int(start), int(stop)
            except Exception:
                continue
            if start >= content_length:
                # From the original server specification:
                # "If the server receives a request (other than one
                # including an If-Range request-header field) with an
                # unsatisfiable Range request-header field (that is, all of
                # whose byte-range-spec values have a first-byte-pos value
                # greater than the current length of the selected resource),
                # it SHOULD return a response code of 416 (Requested range
                # not satisfiable)."
                continue
            if stop < start:
                continue
            stop = min(stop, content_length - 1)
            result.append(Range(start, stop, stop - start + 1))
        elif stop:
            # Negative subscript (last N bytes)
            try:
                stop = int(stop)
            except Exception:
                continue
            if stop > content_length:
                result.append(Range(0, content_length-1, content_length))
            else:
                result.append(Range(content_length - stop, content_length - 1, stop))

    return result
# }}}
# gzip transfer encoding {{{
def gzip_prefix(mtime=None):
    ''' Return the 10 byte gzip member header. See
    http://www.gzip.org/zlib/rfc-gzip.html '''
    if mtime is None:
        mtime = time.time()
    return b''.join((
        b'\x1f\x8b',       # ID1 and ID2: gzip marker
        b'\x08',           # CM: compression method
        b'\x00',           # FLG: none set
        # MTIME: 4 bytes
        struct.pack(b"<L", int(mtime) & 0xFFFFFFFF),
        b'\x02',           # XFL: max compression, slowest algo
        b'\xff',           # OS: unknown
    ))


def compress_readable_output(src_file, compress_level=6):
    ''' Generator yielding a valid gzip stream compressing the contents of
    src_file, read in DEFAULT_BUFFER_SIZE chunks. '''
    crc = zlib.crc32(b"")
    size = 0
    zobj = zlib.compressobj(compress_level,
                            zlib.DEFLATED, -zlib.MAX_WBITS,
                            zlib.DEF_MEM_LEVEL, 0)
    prefix_written = False
    while True:
        data = src_file.read(DEFAULT_BUFFER_SIZE)
        if not data:
            break
        size += len(data)
        crc = zlib.crc32(data, crc)
        data = zobj.compress(data)
        if not prefix_written:
            prefix_written = True
            data = gzip_prefix(time.time()) + data
        yield data
    if not prefix_written:
        # src_file was empty: the header must still be emitted, otherwise the
        # overall stream is not valid gzip
        yield gzip_prefix(time.time())
    # Trailer: remaining compressed data, CRC32 and uncompressed size
    yield zobj.flush() + struct.pack(b"<L", crc & 0xFFFFFFFF) + struct.pack(b"<L", size & 0xFFFFFFFF)
# }}}
def get_range_parts(ranges, content_type, content_length):  # {{{
    ''' Return the list of encoded part headers for a multipart/byteranges
    response, one per range, followed by the closing separator '''

    def make_part(r):
        lines = ['--%s' % MULTIPART_SEPARATOR,
                 'Content-Range: bytes %d-%d/%d' % (r.start, r.stop, content_length)]
        if content_type:
            lines.append('Content-Type: %s' % content_type)
        lines.append('')
        return ('\r\n'.join(lines)).encode('ascii')

    parts = [make_part(r) for r in ranges]
    parts.append(('--%s--' % MULTIPART_SEPARATOR).encode('ascii'))
    return parts
# }}}
class RequestData(object):  # {{{

    ' Encapsulates all the data for a single HTTP request, for use by request handlers '

    def __init__(self, method, path, query, inheaders, request_body_file, outheaders, response_protocol, static_cache, opts):
        self.method = method
        self.path = path
        self.query = query
        self.inheaders = inheaders
        self.request_body_file = request_body_file
        self.outheaders = outheaders
        self.response_protocol = response_protocol
        self.static_cache = static_cache
        self.opts = opts
        # POST creates a resource, so it defaults to 201, everything else to 200
        self.status_code = httplib.CREATED if self.method == 'POST' else httplib.OK

    def generate_static_output(self, name, generator):
        ' Return the cached static output for name, generating it via generator() on a cache miss '
        ans = self.static_cache.get(name)
        if ans is None:
            ans = self.static_cache[name] = StaticOutput(generator())
        return ans

    def read(self, size=-1):
        ' Read from the request body '
        return self.request_body_file.read(size)
# }}}
class ReadableOutput(object):

    ' A response body backed by a seekable file-like object '

    def __init__(self, output, etag=None, content_length=None):
        self.src_file = output
        self.etag = etag
        self.accept_ranges = True
        self.use_sendfile = False
        if content_length is None:
            # Determine the length by seeking to the end of the file
            output.seek(0, os.SEEK_END)
            content_length = output.tell()
        self.content_length = content_length
        # Rewind so that writing the response starts from the beginning
        self.src_file.seek(0)
def filesystem_file_output(output, outheaders, stat_result):
    ''' Wrap output (an open file on the filesystem) in a ReadableOutput with
    sendfile enabled and an ETag derived from mtime and file name '''
    key = type('')(stat_result.st_mtime) + force_unicode(output.name or '')
    etag = '"%s"' % hashlib.sha1(key).hexdigest()
    ans = ReadableOutput(output, etag=etag, content_length=stat_result.st_size)
    ans.name = output.name
    ans.use_sendfile = True
    return ans
def dynamic_output(output, outheaders):
    ''' Wrap a dynamically generated bytes/unicode body in a ReadableOutput,
    defaulting Content-Type to plain text; range requests are not supported '''
    data = output if isinstance(output, bytes) else output.encode('utf-8')
    if not outheaders.get('Content-Type'):
        outheaders.set('Content-Type', 'text/plain; charset=UTF-8', replace_all=True)
    ans = ReadableOutput(BytesIO(data))
    ans.accept_ranges = False
    return ans
class GeneratedOutput(object):

    ' A response body produced by an iterator of chunks; length unknown, so no range support '

    def __init__(self, output, etag=None):
        self.output = output
        self.etag = etag
        self.content_length = None
        self.accept_ranges = False
class StaticOutput(object):

    ' A static blob of response data, with a strong ETag computed from its contents '

    def __init__(self, data):
        if isinstance(data, type('')):
            data = data.encode('utf-8')
        self.data = data
        self.content_length = len(data)
        self.etag = '"%s"' % hashlib.sha1(data).hexdigest()
class HTTPConnection(HTTPRequest):

    ' Generates HTTP responses for the requests parsed by HTTPRequest '

    def write(self, buf, end=None):
        ''' Write as much of buf (a file-like object) to the socket as
        possible without blocking, using sendfile() when available. Returns
        True iff all of buf up to end has been written. '''
        pos = buf.tell()
        if end is None:
            buf.seek(0, os.SEEK_END)
            end = buf.tell()
            buf.seek(pos)
        limit = end - pos
        if limit == 0:
            return True
        if self.use_sendfile and not isinstance(buf, BytesIO):
            try:
                sent = sendfile_to_socket_async(buf, pos, limit, self.socket)
            except CannotSendfile:
                # Fall back to a normal copy on the next write event
                self.use_sendfile = False
                return False
            except SendfileInterrupted:
                return False
            except IOError as e:
                if e.errno in socket_errors_socket_closed:
                    self.ready = self.use_sendfile = False
                    return False
                raise
            if sent == 0:
                # Something bad happened, was the file modified on disk by
                # another process?
                self.use_sendfile = self.ready = False
                raise IOError('sendfile() failed to write any bytes to the socket')
        else:
            sent = self.send(buf.read(min(limit, PIPE_BUF)))
        buf.seek(pos + sent)
        return buf.tell() == end

    def simple_response(self, status_code, msg='', close_after_response=True):
        ' Send a simple response with the specified status code and plain text body '
        if self.response_protocol is HTTP1 and status_code in (httplib.REQUEST_ENTITY_TOO_LARGE, httplib.REQUEST_URI_TOO_LONG):
            # HTTP/1.0 has no 413/414 codes
            status_code = httplib.BAD_REQUEST
        self.close_after_response = close_after_response
        msg = msg.encode('utf-8')
        ct = 'http' if self.method == 'TRACE' else 'plain'
        buf = [
            '%s %d %s' % (self.response_protocol, status_code, httplib.responses[status_code]),
            "Content-Length: %s" % len(msg),
            "Content-Type: text/%s; charset=UTF-8" % ct,
            "Date: " + http_date(),
        ]
        if self.close_after_response and self.response_protocol is HTTP11:
            buf.append("Connection: close")
        buf.append('')
        buf = [(x + '\r\n').encode('ascii') for x in buf]
        if self.method != 'HEAD':
            buf.append(msg)
        self.response_ready(BytesIO(b''.join(buf)))

    def prepare_response(self, inheaders, request_body_file):
        ' Run the request handler and send its output as the response '
        if self.method == 'TRACE':
            msg = force_unicode(self.request_line, 'utf-8') + '\n' + inheaders.pretty()
            return self.simple_response(httplib.OK, msg, close_after_response=False)
        request_body_file.seek(0)
        outheaders = MultiDict()
        data = RequestData(self.method, self.path, self.query, inheaders, request_body_file, outheaders, self.response_protocol, self.static_cache, self.opts)
        try:
            output = self.request_handler(data)
        except HTTP404 as e:
            return self.simple_response(httplib.NOT_FOUND, msg=e.message or '', close_after_response=False)
        # Fixed: was `self.method is HTTP1`, which is always False since
        # method is e.g. 'GET'. The is_http1 test must use the protocol,
        # otherwise HTTP/1.0 clients get chunked/compressed responses.
        output = self.finalize_output(output, data, self.response_protocol is HTTP1)
        if output is None:
            return
        outheaders.set('Date', http_date(), replace_all=True)
        outheaders.set('Server', 'calibre %s' % __version__, replace_all=True)
        keep_alive = not self.close_after_response and self.opts.timeout > 0
        if keep_alive:
            outheaders.set('Keep-Alive', 'timeout=%d' % self.opts.timeout)
        if 'Connection' not in outheaders:
            if self.response_protocol is HTTP11:
                # In HTTP/1.1 persistence is the default, only signal closing
                if self.close_after_response:
                    outheaders.set('Connection', 'close')
            else:
                # In HTTP/1.0 closing is the default, only signal persistence
                if not self.close_after_response:
                    outheaders.set('Connection', 'Keep-Alive')

        ct = outheaders.get('Content-Type', '')
        if ct.startswith('text/') and 'charset=' not in ct:
            outheaders.set('Content-Type', ct + '; charset=UTF-8')

        buf = [HTTP11 + (' %d ' % data.status_code) + httplib.responses[data.status_code]]
        for header, value in sorted(outheaders.iteritems(), key=itemgetter(0)):
            buf.append('%s: %s' % (header, value))
        buf.append('')
        self.response_ready(BytesIO(b''.join((x + '\r\n').encode('ascii') for x in buf)), output=output)

    def send_range_not_satisfiable(self, content_length):
        ' Send a 416 response advertising the actual resource size '
        buf = [
            '%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]),
            "Date: " + http_date(),
            "Content-Range: bytes */%d" % content_length,
        ]
        self.response_ready(header_list_to_file(buf))

    def send_not_modified(self, etag=None):
        ' Send a 304 response, with the ETag if one is known '
        buf = [
            '%s %d %s' % (self.response_protocol, httplib.NOT_MODIFIED, httplib.responses[httplib.NOT_MODIFIED]),
            "Content-Length: 0",
            "Date: " + http_date(),
        ]
        if etag is not None:
            buf.append('ETag: ' + etag)
        self.response_ready(header_list_to_file(buf))

    def response_ready(self, header_file, output=None):
        ' Start writing the response: header_file contains the status line and headers '
        self.response_started = True
        # Cork the socket so headers and body coalesce into fewer packets
        start_cork(self.socket)
        self.corked = True
        self.use_sendfile = False
        self.set_state(WRITE, self.write_response_headers, header_file, output)

    def write_response_headers(self, buf, output, event):
        'State: write the response headers, then start on the body'
        if self.write(buf):
            self.write_response_body(output)

    def write_response_body(self, output):
        'Choose the body-writing state based on the type of output'
        if output is None or self.method == 'HEAD':
            self.reset_state()
            return
        if isinstance(output, ReadableOutput):
            self.use_sendfile = output.use_sendfile and self.opts.use_sendfile and sendfile_to_socket_async is not None
            if output.ranges is not None:
                if isinstance(output.ranges, Range):
                    # A single range: write just that slice of the file
                    r = output.ranges
                    output.src_file.seek(r.start)
                    self.set_state(WRITE, self.write_buf, output.src_file, end=r.stop + 1)
                else:
                    # Multiple ranges: multipart/byteranges response
                    self.set_state(WRITE, self.write_ranges, output.src_file, output.ranges, first=True)
            else:
                self.set_state(WRITE, self.write_buf, output.src_file)
        elif isinstance(output, GeneratedOutput):
            # The trailing None sentinel triggers the final zero-sized chunk
            self.set_state(WRITE, self.write_iter, chain(output.output, repeat(None, 1)))
        else:
            raise TypeError('Unknown output type: %r' % output)

    def write_buf(self, buf, event, end=None):
        'State: write buf (up to end) then become ready for the next request'
        if self.write(buf, end=end):
            self.reset_state()

    def write_ranges(self, buf, ranges, event, first=False):
        'State: start writing the next part of a multipart/byteranges body'
        r, range_part = next(ranges)
        if r is None:
            # EOF range part
            self.set_state(WRITE, self.write_buf, BytesIO(b'\r\n' + range_part))
        else:
            buf.seek(r.start)
            self.set_state(WRITE, self.write_range_part, BytesIO((b'' if first else b'\r\n') + range_part + b'\r\n'), buf, r.stop + 1, ranges)

    def write_range_part(self, part_buf, buf, end, ranges, event):
        'State: write the part headers, then the part data'
        if self.write(part_buf):
            self.set_state(WRITE, self.write_range, buf, end, ranges)

    def write_range(self, buf, end, ranges, event):
        'State: write the data for one range, then move to the next part'
        if self.write(buf, end=end):
            self.set_state(WRITE, self.write_ranges, buf, ranges)

    def write_iter(self, output, event):
        'State: get the next chunk from the iterator and encode it for chunked transfer'
        chunk = next(output)
        if chunk is None:
            # Sentinel reached, send the terminating zero-sized chunk
            self.set_state(WRITE, self.write_chunk, BytesIO(b'0\r\n\r\n'), output, last=True)
        else:
            if not isinstance(chunk, bytes):
                chunk = chunk.encode('utf-8')
            chunk = ('%X\r\n' % len(chunk)).encode('ascii') + chunk + b'\r\n'
            self.set_state(WRITE, self.write_chunk, BytesIO(chunk), output)

    def write_chunk(self, buf, output, event, last=False):
        'State: write an encoded chunk, then fetch the next one (or finish)'
        if self.write(buf):
            if last:
                self.reset_state()
            else:
                self.set_state(WRITE, self.write_iter, output)

    def reset_state(self):
        ' Response complete: become ready for the next request on this connection '
        self.connection_ready()
        self.ready = not self.close_after_response
        stop_cork(self.socket)
        self.corked = False

    def report_unhandled_exception(self, e, formatted_traceback):
        self.simple_response(httplib.INTERNAL_SERVER_ERROR)

    def finalize_output(self, output, request, is_http1):
        ''' Convert the request handler output into a *Output object, applying
        conditional-request (ETag), range and compression logic. Returns None
        if a complete response (304/412/416) has already been sent. '''
        opts = self.opts
        outheaders = request.outheaders
        stat_result = file_metadata(output)
        if stat_result is not None:
            output = filesystem_file_output(output, outheaders, stat_result)
            if 'Content-Type' not in outheaders:
                mt = guess_type(output.name)[0]
                if mt:
                    if mt in {'text/plain', 'text/html', 'application/javascript', 'text/css'}:
                        # Fixed: was `mt =+ '; charset=UTF-8'` which assigned
                        # the bare suffix instead of appending it
                        mt += '; charset=UTF-8'
                    outheaders['Content-Type'] = mt
        elif isinstance(output, (bytes, type(''))):
            output = dynamic_output(output, outheaders)
        elif hasattr(output, 'read'):
            output = ReadableOutput(output)
        elif isinstance(output, StaticOutput):
            output = ReadableOutput(BytesIO(output.data), etag=output.etag, content_length=output.content_length)
        else:
            output = GeneratedOutput(output)
        ct = outheaders.get('Content-Type', '').partition(';')[0]
        compressible = (not ct or ct.startswith('text/') or ct.startswith('image/svg') or
                        ct in {'application/json', 'application/javascript'})
        compressible = (compressible and request.status_code == httplib.OK and
                        (opts.compress_min_size > -1 and output.content_length >= opts.compress_min_size) and
                        acceptable_encoding(request.inheaders.get('Accept-Encoding', '')) and not is_http1)
        accept_ranges = (not compressible and output.accept_ranges is not None and request.status_code == httplib.OK and
                         not is_http1)
        ranges = get_ranges(request.inheaders.get('Range'), output.content_length) if output.accept_ranges and self.method in ('GET', 'HEAD') else None
        if_range = (request.inheaders.get('If-Range') or '').strip()
        if if_range and if_range != output.etag:
            # Stale validator: ignore the Range header and send the full entity
            ranges = None
        if ranges is not None and not ranges:
            return self.send_range_not_satisfiable(output.content_length)

        for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'):
            # Fixed: was outheaders.pop('header', ...) which popped the literal
            # string 'header' instead of each header in turn
            outheaders.pop(header, all=True)

        none_match = parse_if_none_match(request.inheaders.get('If-None-Match', ''))
        matched = '*' in none_match or (output.etag and output.etag in none_match)
        if matched:
            if self.method in ('GET', 'HEAD'):
                self.send_not_modified(output.etag)
            else:
                self.simple_response(httplib.PRECONDITION_FAILED)
            return

        output.ranges = None

        if output.etag and self.method in ('GET', 'HEAD'):
            outheaders.set('ETag', output.etag, replace_all=True)
        if accept_ranges:
            outheaders.set('Accept-Ranges', 'bytes', replace_all=True)
        if compressible and not ranges:
            outheaders.set('Content-Encoding', 'gzip', replace_all=True)
            output = GeneratedOutput(compress_readable_output(output.src_file), etag=output.etag)
        if output.content_length is not None and not compressible and not ranges:
            outheaders.set('Content-Length', '%d' % output.content_length, replace_all=True)

        if compressible or output.content_length is None:
            outheaders.set('Transfer-Encoding', 'chunked', replace_all=True)

        if ranges:
            if len(ranges) == 1:
                r = ranges[0]
                outheaders.set('Content-Length', '%d' % r.size, replace_all=True)
                outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True)
                output.ranges = r
            else:
                range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length)
                # Each part is followed by its data and a CRLF pair (+4 bytes)
                size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges)
                outheaders.set('Content-Length', '%d' % size, replace_all=True)
                outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True)
                output.ranges = izip_longest(ranges, range_parts)
            request.status_code = httplib.PARTIAL_CONTENT
        return output
def create_http_handler(handler):
    ''' Return a factory for HTTPConnection objects that use handler to
    generate responses. All connections created by one factory share a single
    static output cache. '''
    static_cache = {}

    @wraps(handler)
    def wrapper(*args, **kwargs):
        ans = HTTPConnection(*args, **kwargs)
        ans.request_handler = handler
        # Fixed: previously each connection got a fresh {} here, which
        # defeated the shared cache (and left static_cache unused)
        ans.static_cache = static_cache
        return ans
    return wrapper

View File

@ -6,510 +6,132 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3' __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import socket, os, ssl, time, sys import ssl, socket, select, os, traceback
from operator import and_ from io import BytesIO
from Queue import Queue, Full from functools import partial
from threading import Thread, current_thread, Lock
from io import DEFAULT_BUFFER_SIZE, BytesIO
from calibre.srv.errors import MaxSizeExceeded from calibre import as_unicode
from calibre.ptempfile import TemporaryDirectory
from calibre.srv.opts import Options from calibre.srv.opts import Options
from calibre.srv.utils import socket_errors_to_ignore, socket_error_eintr, socket_errors_nonblocking, Corked, HandleInterrupt from calibre.srv.utils import (
socket_errors_socket_closed, socket_errors_nonblocking, HandleInterrupt, socket_errors_eintr)
from calibre.utils.socket_inheritance import set_socket_inherit from calibre.utils.socket_inheritance import set_socket_inherit
from calibre.utils.logging import ThreadSafeLog from calibre.utils.logging import ThreadSafeLog
from calibre.utils.monotonic import monotonic
class SocketFile(object): # {{{ READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR'
"""Faux file object attached to a socket object. Works with non-blocking
sockets, unlike the fileobject created by socket.makefile() """
name = "<socket>" class Connection(object):
__slots__ = ( def __init__(self, socket, opts, ssl_context, tdir):
"mode", "bufsize", "softspace", "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", "_wbuf_len", "_close", 'bytes_read', 'bytes_written', self.opts = opts
) self.tdir = tdir
self.ssl_context = ssl_context
def __init__(self, sock, bufsize=-1, close=False): self.wait_for = READ
self._sock = sock self.response_started = False
self.bytes_read = self.bytes_written = 0 if self.ssl_context is not None:
self.mode = 'r+b' self.ready = False
self.bufsize = DEFAULT_BUFFER_SIZE if bufsize < 0 else bufsize self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False)
self.softspace = False self.set_state(RDWR, self.do_ssl_handshake)
# _rbufsize is the suggested recv buffer size. It is *strictly*
# obeyed within readline() for recv calls. If it is larger than
# default_bufsize it will be used for recv calls within read().
if self.bufsize == 0:
self._rbufsize = 1
elif bufsize == 1:
self._rbufsize = DEFAULT_BUFFER_SIZE
else: else:
self._rbufsize = bufsize self.ready = True
self._wbufsize = bufsize self.socket = socket
# We use BytesIO for the read buffer to avoid holding a list self.connection_ready()
# of variously sized string objects which have been known to self.last_activity = monotonic()
# fragment the heap due to how they are malloc()ed and often
# realloc()ed down much smaller than their original allocation.
self._rbuf = BytesIO()
self._wbuf = [] # A list of strings
self._wbuf_len = 0
self._close = close
@property def set_state(self, wait_for, func, *args, **kwargs):
def closed(self): self.wait_for = wait_for
return self._sock is None if args or kwargs:
func = partial(func, *args, **kwargs)
self.handle_event = func
def do_ssl_handshake(self, event):
try:
self._sslobj.do_handshake()
except ssl.SSLWantReadError:
self.set_state(READ, self.do_ssl_handshake)
except ssl.SSLWantWriteError:
self.set_state(WRITE, self.do_ssl_handshake)
self.ready = True
self.connection_ready()
def send(self, data):
try:
ret = self.socket.send(data)
self.last_activity = monotonic()
return ret
except socket.error as e:
if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
return 0
elif e.errno in socket_errors_socket_closed:
self.ready = False
return 0
raise
def recv(self, buffer_size):
try:
data = self.socket.recv(buffer_size)
self.last_activity = monotonic()
if not data:
# a closed connection is indicated by signaling
# a read condition, and having recv() return 0.
self.ready = False
return b''
return data
except socket.error as e:
if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
return b''
if e.errno in socket_errors_socket_closed:
self.ready = False
return b''
raise
def close(self): def close(self):
try: self.ready = False
if self._sock is not None: self.handle_event = None # prevent reference cycles
try:
self.flush()
except socket.error:
pass
finally:
if self._close and self._sock is not None:
self._sock.close()
self._sock = None
def __del__(self):
try:
self.close()
except:
# close() may fail if __init__ didn't complete
pass
def fileno(self):
return self._sock.fileno()
def gettimeout(self):
return self._sock.gettimeout()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def flush(self):
if self._wbuf_len:
data = b''.join(self._wbuf)
self._wbuf = []
self._wbuf_len = 0
data_size = len(data)
view = memoryview(data)
write_offset = 0
buffer_size = max(self._rbufsize, DEFAULT_BUFFER_SIZE)
try:
while write_offset < data_size:
try:
bytes_sent = self._sock.send(view[write_offset:write_offset+buffer_size])
write_offset += bytes_sent
self.bytes_written += bytes_sent
except socket.error as e:
if e.args[0] not in socket_errors_nonblocking:
raise
finally:
if write_offset < data_size:
remainder = data[write_offset:]
self._wbuf.append(remainder)
self._wbuf_len = len(remainder)
del view, data # explicit free
def write(self, data):
if not isinstance(data, bytes):
raise TypeError('Cannot write data of type: %s to a socket' % type(data))
if not data:
return
self._wbuf.append(data)
self._wbuf_len += len(data)
if self._wbufsize == 0 or (self._wbufsize == 1 and b'\n' in data) or (self._wbufsize > 1 and self._wbuf_len >= self._wbufsize):
self.flush()
def writelines(self, lines):
for line in lines:
self.write(line)
def recv(self, size):
while True:
try:
data = self._sock.recv(size)
self.bytes_read += len(data)
return data
except socket.error, e:
if e.args[0] not in socket_errors_nonblocking and e.args[0] not in socket_error_eintr:
raise
def read(self, size=-1):
# Use max, disallow tiny reads in a loop as they are very inefficient.
# We never leave read() with any leftover data from a new recv() call
# in our internal buffer.
rbufsize = max(self._rbufsize, DEFAULT_BUFFER_SIZE)
buf = self._rbuf
buf.seek(0, os.SEEK_END)
if size < 0:
# Read until EOF
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
while True:
data = self.recv(rbufsize)
if not data:
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or EOF seen, whichever comes first
buf_len = buf.tell()
if buf_len >= size:
# Already have size bytes in our buffer? Extract and return.
buf.seek(0)
rv = buf.read(size)
self._rbuf = BytesIO()
self._rbuf.write(buf.read())
return rv
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
while True:
left = size - buf_len
# recv() will malloc the amount of memory given as its
# parameter even though it often returns much less data
# than that. The returned data string is short lived
# as we copy it into a StringIO and free it. This avoids
# fragmentation issues on many platforms.
data = self.recv(left)
if not data:
break
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid buffer data copies when:
# - We have no data in our buffer.
# AND
# - Our call to recv returned exactly the
# number of bytes we were asked to read.
return data
if n == left:
buf.write(data)
del data # explicit free
break
buf.write(data) # noqa
buf_len += n
del data # noqa explicit free
return buf.getvalue()
def readline(self, size=-1, maxsize=sys.maxsize):
    """Read and return one line (up to *size* bytes if size >= 0).

    Raises MaxSizeExceeded as soon as the accumulated line grows beyond
    *maxsize*; used to defend against malicious over-long header lines.
    """
    buf = self._rbuf
    buf.seek(0, os.SEEK_END)
    if buf.tell() > 0:
        # check if we already have it in our buffer
        buf.seek(0)
        bline = buf.readline(size)
        self._rbuf = BytesIO()
        if bline.endswith(b'\n') or len(bline) == size:
            # Complete line (or size limit hit): return it, keep the rest
            self._rbuf.write(buf.read())
            if len(bline) > maxsize:
                raise MaxSizeExceeded('Line length', len(bline), maxsize)
            return bline
        else:
            # Partial line: put everything back and fall through to recv()
            self._rbuf.write(bline)
            self._rbuf.write(buf.read())
            del bline
    if size < 0:
        # Read until \n or EOF, whichever comes first
        if self._rbufsize <= 1:
            # Speed up unbuffered case
            buf.seek(0)
            buffers = [buf.read()]
            self._rbuf = BytesIO()  # reset _rbuf. we consume it via buf.
            data = None
            recv = self.recv
            sz = len(buffers[0])
            while data != b'\n':
                data = recv(1)
                if not data:
                    break
                sz += 1
                if sz > maxsize:
                    raise MaxSizeExceeded('Line length', sz, maxsize)
                buffers.append(data)
            return b''.join(buffers)
        buf.seek(0, os.SEEK_END)
        self._rbuf = BytesIO()  # reset _rbuf. we consume it via buf.
        while True:
            data = self.recv(self._rbufsize)
            if not data:
                break
            nl = data.find(b'\n')
            if nl >= 0:
                nl += 1
                # Line complete: keep up to the newline, excess goes to _rbuf
                buf.write(data[:nl])
                self._rbuf.write(data[nl:])
                del data
                break
            buf.write(data)  # noqa
            if buf.tell() > maxsize:
                raise MaxSizeExceeded('Line length', buf.tell(), maxsize)
        return buf.getvalue()
    else:
        # Read until size bytes or \n or EOF seen, whichever comes first
        buf.seek(0, os.SEEK_END)
        buf_len = buf.tell()
        if buf_len >= size:
            buf.seek(0)
            rv = buf.read(size)
            self._rbuf = BytesIO()
            self._rbuf.write(buf.read())
            if len(rv) > maxsize:
                raise MaxSizeExceeded('Line length', len(rv), maxsize)
            return rv
        self._rbuf = BytesIO()  # reset _rbuf. we consume it via buf.
        while True:
            data = self.recv(self._rbufsize)
            if not data:
                break
            left = size - buf_len
            # did we just receive a newline?
            nl = data.find(b'\n', 0, left)
            if nl >= 0:
                nl += 1
                # save the excess data to _rbuf
                self._rbuf.write(data[nl:])
                if buf_len:
                    buf.write(data[:nl])
                    break
                else:
                    # Shortcut. Avoid data copy through buf when returning
                    # a substring of our first recv() and buf has no
                    # existing data.
                    if nl > maxsize:
                        raise MaxSizeExceeded('Line length', nl, maxsize)
                    return data[:nl]
            n = len(data)
            if n == size and not buf_len:
                # Shortcut. Avoid data copy through buf when
                # returning exactly all of our first recv().
                if n > maxsize:
                    raise MaxSizeExceeded('Line length', n, maxsize)
                return data
            if n >= left:
                # Got at least size bytes: keep size, excess to _rbuf
                buf.write(data[:left])
                self._rbuf.write(data[left:])
                break
            buf.write(data)
            buf_len += n
            if buf.tell() > maxsize:
                raise MaxSizeExceeded('Line length', buf.tell(), maxsize)
        return buf.getvalue()
def readlines(self, sizehint=0, maxsize=sys.maxsize):
    """Read lines until EOF and return them as a list. If *sizehint* is
    non-zero, stop once at least that many bytes have been collected."""
    collected, consumed = [], 0
    line = self.readline(maxsize=maxsize)
    while line:
        collected.append(line)
        consumed += len(line)
        if sizehint and consumed >= sizehint:
            break
        line = self.readline(maxsize=maxsize)
    return collected
def __iter__(self):
    """Yield successive lines until readline() returns empty (EOF)."""
    while True:
        line = self.readline()
        if not line:
            break
        yield line
# }}}
class Connection(object):  # {{{

    'A thin wrapper around an active socket'

    remote_addr = None
    remote_port = None

    def __init__(self, server_loop, socket):
        # Keep a reference to the owning loop and wrap the raw socket in the
        # corked/buffered helpers used by the request handlers
        self.server_loop = server_loop
        self.socket = socket
        self.corked = Corked(socket)
        self.socket_file = SocketFile(socket)
        self.closed = False

    def close(self):
        """Close the socket underlying this connection (idempotent)."""
        if not self.closed:
            self.socket_file.close()
            try:
                self.socket.shutdown(socket.SHUT_WR)
                self.socket.close()
            except socket.error:
                # Peer may already have disconnected
                pass
            self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
# }}}
class WorkerThread(Thread):  # {{{

    """Daemon thread that services connections pulled off the server's
    request queue, one at a time, until it receives a None sentinel."""

    daemon = True

    def __init__(self, server_loop):
        self.serving = False
        self.server_loop = server_loop
        self.conn = None
        # Set by ThreadPool.stop() when it forcibly closes our socket, so
        # the resulting socket.error is expected and swallowed
        self.forcible_shutdown = False
        Thread.__init__(self, name='ServerWorker')

    def run(self):
        try:
            while True:
                self.serving = False
                conn = self.server_loop.requests.get()
                self.conn = conn
                if conn is None:
                    return  # Sentinel: clean exit
                # Entering self marks us as serving; the connection is
                # closed automatically when handling finishes
                with conn, self:
                    self.server_loop.req_resp_handler(conn)
        except (KeyboardInterrupt, SystemExit):
            self.server_loop.stop()
        except socket.error:
            if not self.forcible_shutdown:
                raise

    def __enter__(self):
        self.serving = True
        return self

    def __exit__(self, *args):
        self.serving = False
# }}}
class ThreadPool(object):  # {{{

    """A growable/shrinkable pool of WorkerThreads feeding off a shared
    queue of accepted connections."""

    def __init__(self, server_loop, min_threads=10, max_threads=-1, accepted_queue_size=-1, accepted_queue_timeout=10):
        self.server_loop = server_loop
        self.min_threads = max(1, min_threads)
        self.max_threads = max_threads
        self._threads = []
        self._queue = Queue(maxsize=accepted_queue_size)
        self._queue_put_timeout = accepted_queue_timeout
        # Expose the queue's get() directly for workers
        self.get = self._queue.get

    def start(self):
        """Start the pool of threads."""
        self._threads = [self._spawn_worker() for i in xrange(self.min_threads)]

    @property
    def idle(self):
        # Number of workers currently waiting for a connection
        return sum(int(not worker.serving) for worker in self._threads)

    @property
    def busy(self):
        # Number of workers currently handling a connection
        return sum(int(worker.serving) for worker in self._threads)

    def put(self, obj):
        self._queue.put(obj, block=True, timeout=self._queue_put_timeout)

    def grow(self, amount):
        """Spawn new worker threads (not above self.max_threads)."""
        if self.max_threads > 0:
            budget = max(self.max_threads - len(self._threads), 0)
        else:
            budget = sys.maxsize
        self._threads.extend([self._spawn_worker() for i in xrange(min(amount, budget))])

    def _spawn_worker(self):
        worker = WorkerThread(self.server_loop)
        worker.start()
        return worker

    @staticmethod
    def _all(func, items):
        # Boolean AND of func over items (True for an empty sequence)
        return reduce(and_, [func(item) for item in items], True)

    def shrink(self, amount):
        """Kill off worker threads (not below self.min_threads)."""
        # Cull dead threads first and count them against the requested amount
        before = len(self._threads)
        self._threads = [t for t in self._threads if t.is_alive()]
        amount -= before - len(self._threads)
        # Never drop below the configured minimum
        surplus = max(len(self._threads) - self.min_threads, 0)
        # Each None sentinel on the queue causes one worker to terminate; it
        # will be culled from self._threads on a later shrink()
        for n in xrange(min(amount, surplus)):
            self._queue.put(None)

    def stop(self, timeout=5):
        """Shut down all workers, forcibly closing their sockets if they do
        not finish within *timeout* seconds.

        Must shut down threads here so the code that calls this method can
        know when all threads are stopped."""
        for worker in self._threads:
            self._queue.put(None)
        # Don't join the current thread (this should never happen, since
        # ServerLoop calls stop() in its own thread, but better to be safe).
        current = current_thread()
        if timeout and timeout >= 0:
            endtime = time.time() + timeout
        while self._threads:
            worker = self._threads.pop()
            if worker is not current and worker.is_alive():
                try:
                    if timeout is None or timeout < 0:
                        worker.join()
                    else:
                        remaining_time = endtime - time.time()
                        if remaining_time > 0:
                            worker.join(remaining_time)
                        if worker.is_alive():
                            # We exhausted the timeout. Forcibly shut down
                            # the socket so the worker errors out.
                            worker.forcible_shutdown = True
                            c = worker.conn
                            if c and not c.socket_file.closed:
                                c.socket.shutdown(socket.SHUT_RDWR)
                                c.socket.close()
                            worker.join()
                except KeyboardInterrupt:
                    pass  # Ignore repeated Ctrl-C.

    @property
    def qsize(self):
        return self._queue.qsize()
# }}}
class ServerLoop(object): class ServerLoop(object):
def __init__( def __init__(
self, self,
req_resp_handler, handler,
bind_address=('localhost', 8080), bind_address=('localhost', 8080),
opts=None, opts=None,
# A calibre logging object. If None a default log that logs to # A calibre logging object. If None, a default log that logs to
# stdout is used # stdout is used
log=None log=None
): ):
self.ready = False
self.handler = handler
self.opts = opts or Options() self.opts = opts or Options()
self.req_resp_handler = req_resp_handler
self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG) self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG)
self.gso_cache, self.gso_lock = {}, Lock()
ba = bind_address ba = tuple(bind_address)
if not isinstance(ba, basestring): if not ba[0]:
ba = tuple(ba) # AI_PASSIVE does not work with host of '' or None
if not ba[0]: ba = ('0.0.0.0', ba[1])
# AI_PASSIVE does not work with host of '' or None
ba = ('0.0.0.0', ba[1])
self.bind_address = ba self.bind_address = ba
self.bound_address = None self.bound_address = None
self.connection_map = {}
self.ssl_context = None self.ssl_context = None
if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None: if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None:
self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
@ -523,51 +145,33 @@ class ServerLoop(object):
set_socket_inherit(self.pre_activated_socket, False) set_socket_inherit(self.pre_activated_socket, False)
self.bind_address = self.pre_activated_socket.getsockname() self.bind_address = self.pre_activated_socket.getsockname()
self.ready = False
self.requests = ThreadPool(self, min_threads=self.opts.min_threads, max_threads=self.opts.max_threads)
def __str__(self): def __str__(self):
return "%s(%r)" % (self.__class__.__name__, self.bind_address) return "%s(%r)" % (self.__class__.__name__, self.bind_address)
__repr__ = __str__ __repr__ = __str__
@property
def num_active_connections(self):
return len(self.connection_map)
def serve_forever(self): def serve_forever(self):
""" Listen for incoming connections. """ """ Listen for incoming connections. """
if self.pre_activated_socket is None: if self.pre_activated_socket is None:
# Select the appropriate socket # AF_INET or AF_INET6 socket
if isinstance(self.bind_address, basestring): # Get the correct address family for our host (allows IPv6
# AF_UNIX socket # addresses)
host, port = self.bind_address
# So we can reuse the socket... try:
try: info = socket.getaddrinfo(
os.unlink(self.bind_address) host, port, socket.AF_UNSPEC,
except EnvironmentError: socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
pass except socket.gaierror:
if ':' in host:
# So everyone can access the socket... info = [(socket.AF_INET6, socket.SOCK_STREAM,
try: 0, "", self.bind_address + (0, 0))]
os.chmod(self.bind_address, 0777) else:
except EnvironmentError: info = [(socket.AF_INET, socket.SOCK_STREAM,
pass 0, "", self.bind_address)]
info = [
(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_address)]
else:
# AF_INET or AF_INET6 socket
# Get the correct address family for our host (allows IPv6
# addresses)
host, port = self.bind_address
try:
info = socket.getaddrinfo(
host, port, socket.AF_UNSPEC,
socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
except socket.gaierror:
if ':' in host:
info = [(socket.AF_INET6, socket.SOCK_STREAM,
0, "", self.bind_address + (0, 0))]
else:
info = [(socket.AF_INET, socket.SOCK_STREAM,
0, "", self.bind_address)]
self.socket = None self.socket = None
msg = "No socket could be created" msg = "No socket could be created"
@ -589,28 +193,28 @@ class ServerLoop(object):
self.pre_activated_socket = None self.pre_activated_socket = None
self.setup_socket() self.setup_socket()
self.socket.listen(5) self.connection_map = {}
self.socket.listen(min(socket.SOMAXCONN, 128))
self.bound_address = ba = self.socket.getsockname() self.bound_address = ba = self.socket.getsockname()
if isinstance(ba, tuple): if isinstance(ba, tuple):
ba = ':'.join(map(type(''), ba)) ba = ':'.join(map(type(''), ba))
self.log('calibre server listening on', ba) with TemporaryDirectory(prefix='srv-') as tdir:
self.tdir = tdir
self.ready = True
self.log('calibre server listening on', ba)
# Create worker threads while True:
self.requests.start() try:
self.ready = True self.tick()
except (KeyboardInterrupt, SystemExit):
while self.ready: self.shutdown()
try: break
self.tick() except:
except (KeyboardInterrupt, SystemExit): self.log.exception('Error in ServerLoop.tick')
raise
except:
self.log.exception('Error in ServerLoop.tick')
def setup_socket(self): def setup_socket(self):
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if self.opts.no_delay and not isinstance(self.bind_address, basestring): self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# If listening on the IPV6 any address ('::' = IN6ADDR_ANY), # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
# activate dual-stack. # activate dual-stack.
@ -623,149 +227,159 @@ class ServerLoop(object):
# Apparently, the socket option is not available in # Apparently, the socket option is not available in
# this machine's TCP stack # this machine's TCP stack
pass pass
self.socket.setblocking(0)
def bind(self, family, atype, proto=0): def bind(self, family, atype, proto=0):
"""Create (or recreate) the actual socket object.""" '''Create (or recreate) the actual socket object.'''
self.socket = socket.socket(family, atype, proto) self.socket = socket.socket(family, atype, proto)
set_socket_inherit(self.socket, False) set_socket_inherit(self.socket, False)
self.setup_socket() self.setup_socket()
self.socket.bind(self.bind_address) self.socket.bind(self.bind_address)
def tick(self): def tick(self):
"""Accept a new connection and put it on the Queue.""" now = monotonic()
for s, conn in tuple(self.connection_map.iteritems()):
if now - conn.last_activity > self.opts.timeout:
if not conn.handle_timeout():
self.log.debug('Closing connection because of extended inactivity')
self.close(s, conn)
read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR]
write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR]
try: try:
s, addr = self.socket.accept() readable, writable, _ = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout)
if not self.ready: except select.error as e:
if e.errno in socket_errors_eintr:
return return
for s, conn in tuple(self.connection_map.iteritems()):
set_socket_inherit(s, False)
if hasattr(s, 'settimeout'):
s.settimeout(self.opts.timeout)
if self.ssl_context is not None:
try: try:
s = self.ssl_context.wrap_socket(s, server_side=True) select.select([s], [], [], 0)
except ssl.SSLEOFError: except select.error:
return # Ignore, client closed connection self.close(s, conn) # Bad socket, discard
except ssl.SSLError as e:
if e.args[1].endswith('http request'):
msg = (b"The client sent a plain HTTP request, but "
b"this server only speaks HTTPS on this port.")
response = [
b"HTTP/1.1 400 Bad Request\r\n",
str("Content-Length: %s\r\n" % len(msg)),
b"Content-Type: text/plain\r\n\r\n",
msg
]
with SocketFile(s._sock) as f:
f.write(response)
return
elif e.args[1].endswith('unknown protocol'):
return # Drop connection
raise
if hasattr(s, 'settimeout'):
s.settimeout(self.opts.timeout)
conn = Connection(self, s)
if not isinstance(self.bind_address, basestring):
# optional values
# Until we do DNS lookups, omit REMOTE_HOST
if addr is None: # sometimes this can happen
# figure out if AF_INET or AF_INET6.
if len(s.getsockname()) == 2:
# AF_INET
addr = ('0.0.0.0', 0)
else:
# AF_INET6
addr = ('::', 0)
conn.remote_addr = addr[0]
conn.remote_port = addr[1]
try:
self.requests.put(conn)
except Full:
self.log.warn('Server overloaded, dropping connection')
conn.close()
return
except socket.timeout:
# The only reason for the timeout in start() is so we can
# notice keyboard interrupts on Win32, which don't interrupt
# accept() by default
return
except socket.error as e:
if e.args[0] in socket_error_eintr | socket_errors_nonblocking | socket_errors_to_ignore:
return
raise
def stop(self):
""" Gracefully shutdown the server loop. """
if not self.ready: if not self.ready:
return return
# We run the stop code in its own thread so that it is not interrupted
# by KeyboardInterrupt
self.ready = False
t = Thread(target=self._stop)
t.start()
try:
t.join()
except KeyboardInterrupt:
pass
def _stop(self): ignore = set()
self.log('Shutting down server gracefully, waiting for connections to close...') for s, conn, event in self.get_actions(readable, writable):
self.requests.stop(self.opts.shutdown_timeout) if s in ignore:
sock = self.tick_once() continue
if hasattr(sock, "close"): try:
sock.close() conn.handle_event(event)
self.socket = None if not conn.ready:
self.close(s, conn)
except Exception as e:
ignore.add(s)
if conn.ready:
self.log.exception('Unhandled exception in state: %s' % conn.state_description)
if conn.response_started:
self.close(s, conn)
else:
try:
conn.report_unhandled_exception(e, traceback.format_exc())
except Exception:
self.close(s, conn)
else:
self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e))
self.close(s, conn)
def tick_once(self): def wakeup(self):
# Touch our own socket to make accept() return immediately. # Touch our own socket to make select() return immediately.
sock = getattr(self, "socket", None) sock = getattr(self, "socket", None)
if sock is not None: if sock is not None:
if not isinstance(self.bind_address, basestring): try:
try: host, port = sock.getsockname()[:2]
host, port = sock.getsockname()[:2] except socket.error as e:
except socket.error as e: if e.errno not in socket_errors_socket_closed:
if e.args[0] not in socket_errors_to_ignore: raise
raise else:
else: for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
# Ensure tick() returns by opening a transient connection socket.SOCK_STREAM):
# to our own listening socket af, socktype, proto, canonname, sa = res
for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, s = None
socket.SOCK_STREAM): try:
af, socktype, proto, canonname, sa = res s = socket.socket(af, socktype, proto)
s = None s.settimeout(1.0)
try: s.connect((host, port))
s = socket.socket(af, socktype, proto) s.close()
s.settimeout(1.0) except socket.error:
s.connect((host, port)) if s is not None:
s.close() s.close()
except socket.error:
if s is not None:
s.close()
return sock return sock
def echo_handler(conn): def close(self, s, conn):
keep_going = True self.connection_map.pop(s, None)
while keep_going: conn.close()
def get_actions(self, readable, writable):
for s in readable:
if s is self.socket:
s, addr = self.accept()
if s is not None:
self.connection_map[s] = conn = self.handler(s, self.opts, self.ssl_context, self.tdir)
if self.ssl_context is not None:
yield s, conn, RDWR
else:
yield s, self.connection_map[s], READ
for s in writable:
yield s, self.connection_map[s], WRITE
def accept(self):
try: try:
line = conn.socket_file.readline() return self.socket.accept()
except socket.timeout: except socket.error:
continue return None, None
conn.server_loop.log('Received:', repr(line))
if not line.rstrip(): def stop(self):
keep_going = False self.ready = False
line = b'bye\r\n' self.wakeup()
conn.socket_file.write(line)
conn.socket_file.flush() def shutdown(self):
try:
if getattr(self, 'socket', None):
self.socket.close()
self.socket = None
except socket.error:
pass
for s, conn in tuple(self.connection_map.iteritems()):
self.close(s, conn)
class EchoLine(Connection):  # {{{

    """Test/debug connection that echoes back every line it receives. A
    blank line is answered with 'bye' and then the connection is closed."""

    bye_after_echo = False

    def connection_ready(self):
        self.rbuf = BytesIO()
        self.set_state(READ, self.read_line)

    def read_line(self, event):
        byte = self.recv(1)
        if not byte:
            return
        self.rbuf.write(byte)
        if byte == b'\n':
            if self.rbuf.tell() < 3:
                # Empty line (just the CRLF): say goodbye and disconnect
                # once the echo has been written out
                self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue())
                self.bye_after_echo = True
            self.set_state(WRITE, self.echo)
            self.rbuf.seek(0)

    def echo(self, event):
        pos = self.rbuf.tell()
        self.rbuf.seek(0, os.SEEK_END)
        remaining = self.rbuf.tell() - pos
        self.rbuf.seek(pos)
        sent = self.send(self.rbuf.read(512))
        if sent == remaining:
            # Whole line written; go back to reading
            self.rbuf = BytesIO()
            self.set_state(READ, self.read_line)
            if self.bye_after_echo:
                self.ready = False
        else:
            # Partial write: resume from where we stopped next time
            self.rbuf.seek(pos + sent)
# }}}
if __name__ == '__main__': if __name__ == '__main__':
s = ServerLoop(echo_handler) s = ServerLoop(EchoLine)
with HandleInterrupt(s.tick_once): with HandleInterrupt(s.wakeup):
try: s.serve_forever()
s.serve_forever()
except KeyboardInterrupt:
pass
s.stop()

View File

@ -36,14 +36,6 @@ raw_options = (
'shutdown_timeout', 5.0, 'shutdown_timeout', 5.0,
None, None,
'Minimum number of connection handling threads',
'min_threads', 10,
None,
'Maximum number of simultaneous connections (beyond this number of connections will be dropped)',
'max_threads', 500,
None,
'Allow socket pre-allocation, for example, with systemd socket activation', 'Allow socket pre-allocation, for example, with systemd socket activation',
'allow_socket_preallocation', True, 'allow_socket_preallocation', True,
None, None,
@ -52,7 +44,7 @@ raw_options = (
'max_header_line_size', 8.0, 'max_header_line_size', 8.0,
None, None,
'Max. size of a HTTP request (in MB)', 'Max. allowed size for files uploaded to the server (in MB)',
'max_request_body_size', 500.0, 'max_request_body_size', 500.0,
None, None,
@ -60,12 +52,6 @@ raw_options = (
'compress_min_size', 1024, 'compress_min_size', 1024,
None, None,
'Decrease latency by using the TCP_NODELAY feature',
'no_delay', True,
'no_delay turns on TCP_NODELAY which decreases latency at the cost of'
' worse overall performance when sending multiple small packets. It'
' prevents the TCP stack from aggregating multiple small TCP packets.',
'Use zero copy file transfers for increased performance', 'Use zero copy file transfers for increased performance',
'use_sendfile', True, 'use_sendfile', True,
'This will use zero-copy in-kernel transfers when sending files over the network,' 'This will use zero-copy in-kernel transfers when sending files over the network,'

View File

@ -1,356 +0,0 @@
#!/usr/bin/env python2
# vim:fileencoding=utf-8
from __future__ import (unicode_literals, division, absolute_import,
print_function)
__license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import os, hashlib, httplib, zlib, struct, time, uuid
from io import DEFAULT_BUFFER_SIZE, BytesIO
from collections import namedtuple
from functools import partial
from future_builtins import map
from itertools import izip_longest
from calibre import force_unicode, guess_type
from calibre.srv.errors import IfNoneMatch, RangeNotSatisfiable
from calibre.srv.sendfile import file_metadata, copy_range, sendfile_to_socket
Range = namedtuple('Range', 'start stop size')
MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii')
def get_ranges(headervalue, content_length):
    ''' Return a list of ranges from the Range header. If this function returns
    an empty list, it indicates no valid range was found. '''
    if not headervalue:
        return None

    try:
        units, byteranges = headervalue.split("=", 1)
    except Exception:
        return None
    if units.strip() != 'bytes':
        return None

    ans = []
    for spec in byteranges.split(","):
        first, last = [part.strip() for part in spec.split("-", 1)]
        if first:
            # Normal form: start-stop, where stop may be omitted
            if not last:
                last = content_length - 1
            try:
                first, last = int(first), int(last)
            except Exception:
                continue
            if first >= content_length or last < first:
                continue  # Unsatisfiable or inverted range
            last = min(last, content_length - 1)
            ans.append(Range(first, last, last - first + 1))
        elif last:
            # Suffix form: "-N" means the final N bytes
            try:
                last = int(last)
            except Exception:
                continue
            if last > content_length:
                ans.append(Range(0, content_length - 1, content_length))
            else:
                ans.append(Range(content_length - last, content_length - 1, last))
    return ans
def acceptable_encoding(val, allowed=frozenset({'gzip'})):
    """Parse an Accept-Encoding header value and return the supported
    encoding with the highest quality factor, or None if none match."""
    def parse_item(item):
        # Split 'name;q=x' into (lowercased name, quality); quality
        # defaults to 1.0 when absent or unparseable
        name, params = item.partition(';')[::2]
        key, qval = params.partition('=')[::2]
        quality = 1.0
        if key == 'q' and qval:
            try:
                quality = float(qval)
            except Exception:
                pass
        return name.lower(), quality

    qmap = dict(parse_item(x.strip()) for x in val.split(','))
    candidates = sorted(set(qmap) & allowed, key=qmap.__getitem__, reverse=True)
    if candidates:
        return candidates[0]
def gzip_prefix(mtime):
    """Build the fixed 10-byte gzip member header.

    See http://www.gzip.org/zlib/rfc-gzip.html for the field layout.
    """
    mtime_field = struct.pack(b"<L", int(mtime) & 0xFFFFFFFF)  # MTIME, 4 bytes LE
    return (
        b'\x1f\x8b'       # ID1 and ID2: gzip marker
        b'\x08'           # CM: compression method (deflate)
        b'\x00'           # FLG: none set
        + mtime_field
        + b'\x02'         # XFL: max compression, slowest algo
        + b'\xff'         # OS: unknown
    )
def write_chunked_data(dest, data):
    """Write *data* to dest framed as one HTTP chunked-transfer-encoding
    chunk: hex length, CRLF, payload, CRLF."""
    size_line = ('%X\r\n' % len(data)).encode('ascii')
    dest.write(size_line)
    dest.write(data)
    dest.write(b'\r\n')
def write_compressed_file_obj(input_file, dest, compress_level=6):
    """Stream *input_file* through a raw DEFLATE compressor, framing the
    result as a gzip member and writing it to dest in HTTP chunks."""
    crc, size = zlib.crc32(b""), 0
    # Negative wbits: raw deflate stream; we add the gzip header/trailer
    # ourselves so data can be streamed without knowing the size up front
    compressor = zlib.compressobj(compress_level,
                                  zlib.DEFLATED, -zlib.MAX_WBITS,
                                  zlib.DEF_MEM_LEVEL, 0)
    prefix_written = False
    while True:
        raw = input_file.read(DEFAULT_BUFFER_SIZE)
        if not raw:
            break
        size += len(raw)
        crc = zlib.crc32(raw, crc)
        compressed = compressor.compress(raw)
        if not prefix_written:
            prefix_written = True
            compressed = gzip_prefix(time.time()) + compressed
        write_chunked_data(dest, compressed)
    # gzip trailer: remaining compressed bytes, CRC32 and ISIZE (mod 2**32),
    # followed by the zero-length chunk that terminates the chunked stream
    trailer = compressor.flush() + struct.pack(b"<L", crc & 0xFFFFFFFF) + struct.pack(b"<L", size & 0xFFFFFFFF)
    write_chunked_data(dest, trailer)
    write_chunked_data(dest, b'')
def get_range_parts(ranges, content_type, content_length):
    """Build the per-range header blocks for a multipart/byteranges
    response body, terminated by the closing boundary marker."""
    def header_block(r):
        lines = ['--%s' % MULTIPART_SEPARATOR,
                 'Content-Range: bytes %d-%d/%d' % (r.start, r.stop, content_length)]
        if content_type:
            lines.append('Content-Type: %s' % content_type)
        lines.append('')  # blank line separating headers from the range data
        return ('\r\n'.join(lines)).encode('ascii')
    return [header_block(r) for r in ranges] + [('--%s--' % MULTIPART_SEPARATOR).encode('ascii')]
def parse_multipart_byterange(buf, content_type):
    """Parse a multipart/byteranges response body from *buf*.

    The boundary separator is taken from the boundary= parameter at the end
    of *content_type*. Returns a list of (start_offset, data) tuples, one
    per sub-part. Raises ValueError on any framing violation.
    """
    from calibre.srv.http import read_headers
    sep = (content_type.rsplit('=', 1)[-1]).encode('utf-8')
    ans = []

    def parse_part():
        # Returns (start, data) for one sub-part, or None on the final
        # '--sep--' terminating boundary
        line = buf.readline()
        if not line:
            raise ValueError('Premature end of message')
        if not line.startswith(b'--' + sep):
            raise ValueError('Malformed start of multipart message')
        if line.endswith(b'--'):
            return None
        headers = read_headers(buf.readline)
        cr = headers.get('Content-Range')
        if not cr:
            raise ValueError('Missing Content-Range header in sub-part')
        if not cr.startswith('bytes '):
            raise ValueError('Malformed Content-Range header in sub-part, no prefix')
        try:
            # 'bytes START-STOP/TOTAL' -> (START, STOP)
            start, stop = map(lambda x: int(x.strip()), cr.partition(' ')[-1].partition('/')[0].partition('-')[::2])
        except Exception:
            raise ValueError('Malformed Content-Range header in sub-part, failed to parse byte range')
        content_length = stop - start + 1
        ret = buf.read(content_length)
        if len(ret) != content_length:
            raise ValueError('Malformed sub-part, length of body not equal to length specified in Content-Range')
        buf.readline()  # consume the CRLF trailing the part data
        return (start, ret)

    while True:
        data = parse_part()
        if data is None:
            break
        ans.append(data)
    return ans
class ReadableOutput(object):

    """Transmits a seekable file-like object as a response body, either
    whole, gzip-compressed, or as HTTP byte ranges, optionally via
    zero-copy sendfile()."""

    def __init__(self, output, outheaders):
        self.src_file = output
        self.src_file.seek(0, os.SEEK_END)
        # Size of the file right now is the response Content-Length
        self.content_length = self.src_file.tell()
        self.etag = None
        self.accept_ranges = True
        self.use_sendfile = False

    def write(self, dest):
        """Send the entire file to dest; raises IOError on a short send."""
        if self.use_sendfile:
            dest.flush()  # Ensure everything in the SocketFile buffer is sent before calling sendfile()
            sent = sendfile_to_socket(self.src_file, 0, self.content_length, dest)
        else:
            sent = copy_range(self.src_file, 0, self.content_length, dest)
        if sent != self.content_length:
            raise IOError(
                'Failed to send complete file (%r) (%s != %s bytes), perhaps the file was modified during send?' % (
                    getattr(self.src_file, 'name', '<file>'), sent, self.content_length))
        self.src_file = None

    def write_compressed(self, dest):
        """Send the whole file gzip-compressed with chunked encoding."""
        self.src_file.seek(0)
        write_compressed_file_obj(self.src_file, dest)
        self.src_file = None

    def write_ranges(self, ranges, dest):
        # ranges is either a single Range, or an iterable of
        # (Range, header_block) pairs for a multipart/byteranges body where
        # the final pair has a None Range (closing boundary only)
        if isinstance(ranges, Range):
            r = ranges
            self.copy_range(r.start, r.size, dest)
        else:
            for r, header in ranges:
                dest.write(header)
                if r is not None:
                    dest.write(b'\r\n')
                    self.copy_range(r.start, r.size, dest)
                    dest.write(b'\r\n')
        self.src_file = None

    def copy_range(self, start, size, dest):
        """Send bytes [start, start+size) of the file to dest."""
        if self.use_sendfile:
            dest.flush()  # Ensure everything in the SocketFile buffer is sent before calling sendfile()
            sent = sendfile_to_socket(self.src_file, start, size, dest)
        else:
            # NOTE: calls the module-level copy_range() helper, which this
            # method deliberately shadows by name
            sent = copy_range(self.src_file, start, size, dest)
        if sent != size:
            raise IOError('Failed to send byte range from file (%r) (%s != %s bytes), perhaps the file was modified during send?' % (
                getattr(self.src_file, 'name', '<file>'), sent, size))
class FileSystemOutputFile(ReadableOutput):

    """A ReadableOutput backed by a real file on disk, with a stat-derived
    ETag and optional zero-copy sendfile() support."""

    def __init__(self, output, outheaders, stat_result, use_sendfile):
        self.src_file = output
        self.name = output.name
        self.content_length = stat_result.st_size
        # Strong ETag derived from mtime + name: changes whenever the
        # file is modified on disk
        etag_src = type('')(stat_result.st_mtime) + force_unicode(output.name or '')
        self.etag = '"%s"' % hashlib.sha1(etag_src).hexdigest()
        self.accept_ranges = True
        self.use_sendfile = use_sendfile and sendfile_to_socket is not None
class DynamicOutput(object):

    """Response body generated in memory as bytes or text; text is encoded
    as UTF-8 and a default plain-text Content-Type is supplied."""

    def __init__(self, output, outheaders):
        self.data = output if isinstance(output, bytes) else output.encode('utf-8')
        ct = outheaders.get('Content-Type')
        if not ct:
            outheaders.set('Content-Type', 'text/plain; charset=UTF-8', replace_all=True)
        self.content_length = len(self.data)
        self.etag = None
        self.accept_ranges = False

    def write(self, dest):
        dest.write(self.data)
        self.data = None  # release the buffer once sent

    def write_compressed(self, dest):
        write_compressed_file_obj(BytesIO(self.data), dest)
class GeneratedOutput(object):

    """Response body produced lazily by an iterator of chunks; sent with
    chunked transfer encoding since the total length is unknown."""

    def __init__(self, output, outheaders):
        self.output = output
        self.content_length = self.etag = None
        self.accept_ranges = False

    def write(self, dest):
        for chunk in self.output:
            # Skip falsy chunks: a zero-length chunk would prematurely
            # terminate the chunked stream
            if chunk:
                write_chunked_data(dest, chunk)
class StaticGeneratedOutput(object):

    """Immutable pre-generated response body, cached and reused across
    requests (see generate_static_output)."""

    def __init__(self, data):
        raw = data.encode('utf-8') if isinstance(data, type('')) else data
        self.data = raw
        self.content_length = len(raw)
        self.accept_ranges = False
        # Strong ETag from the content hash; the content never changes
        self.etag = '"%s"' % hashlib.sha1(raw).hexdigest()

    def write(self, dest):
        dest.write(self.data)

    def write_compressed(self, dest):
        write_compressed_file_obj(BytesIO(self.data), dest)
def generate_static_output(cache, gso_lock, name, generator):
    """Return the cached StaticGeneratedOutput for *name*, building it from
    generator() (under gso_lock) the first time it is requested."""
    with gso_lock:
        try:
            return cache[name]
        except KeyError:
            ans = cache[name] = StaticGeneratedOutput(generator())
            return ans
def parse_if_none_match(val):
    """Split an If-None-Match header value into a set of entity tags."""
    return {tag.strip() for tag in val.split(',')}
def finalize_output(output, inheaders, outheaders, status_code, is_http1, method, opts):
    """Convert an arbitrary handler return value into a writable output
    object, configuring compression, byte ranges, ETag/If-None-Match
    handling and the length/encoding response headers.

    Returns the (possibly updated) status code and the output wrapper,
    whose commit attribute is the callable that writes the body.
    Raises RangeNotSatisfiable or IfNoneMatch for the corresponding
    conditional-request outcomes.
    """
    ct = outheaders.get('Content-Type', '')
    compressible = not ct or ct.startswith('text/') or ct.startswith('image/svg') or ct.startswith('application/json')

    # Wrap the raw output in the appropriate adapter
    stat_result = file_metadata(output)
    if stat_result is not None:
        output = FileSystemOutputFile(output, outheaders, stat_result, opts.use_sendfile)
        if 'Content-Type' not in outheaders:
            mt = guess_type(output.name)[0]
            if mt:
                if mt in ('text/plain', 'text/html'):
                    # Fixed: was "mt =+ ..." (unary plus), which discarded
                    # the guessed type and raised TypeError on a str operand
                    mt += '; charset=UTF-8'
                outheaders['Content-Type'] = mt
    elif isinstance(output, (bytes, type(''))):
        output = DynamicOutput(output, outheaders)
    elif hasattr(output, 'read'):
        output = ReadableOutput(output, outheaders)
    elif isinstance(output, StaticGeneratedOutput):
        pass
    else:
        output = GeneratedOutput(output, outheaders)

    # Compression and range support are mutually exclusive, and both are
    # only offered for 200 responses to non-HTTP/1.0 clients
    compressible = (status_code == httplib.OK and compressible and
                    (opts.compress_min_size > -1 and output.content_length >= opts.compress_min_size) and
                    acceptable_encoding(inheaders.get('Accept-Encoding', '')) and not is_http1)
    accept_ranges = (not compressible and output.accept_ranges is not None and status_code == httplib.OK and
                     not is_http1)
    ranges = get_ranges(inheaders.get('Range'), output.content_length) if output.accept_ranges and method in ('GET', 'HEAD') else None
    if_range = (inheaders.get('If-Range') or '').strip()
    if if_range and if_range != output.etag:
        # Entity changed since the client cached it: send the full body
        ranges = None
    if ranges is not None and not ranges:
        raise RangeNotSatisfiable(output.content_length)

    # Remove any stale entity headers before setting fresh ones below.
    # Fixed: was outheaders.pop('header', ...) — the literal string instead
    # of the loop variable, so none of these were actually removed.
    for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'):
        outheaders.pop(header, all=True)

    none_match = parse_if_none_match(inheaders.get('If-None-Match', ''))
    matched = '*' in none_match or (output.etag and output.etag in none_match)
    if matched:
        raise IfNoneMatch(output.etag)

    if output.etag and method in ('GET', 'HEAD'):
        outheaders.set('ETag', output.etag, replace_all=True)
    if accept_ranges:
        outheaders.set('Accept-Ranges', 'bytes', replace_all=True)
    elif compressible:
        outheaders.set('Content-Encoding', 'gzip', replace_all=True)
    if output.content_length is not None and not compressible and not ranges:
        outheaders.set('Content-Length', '%d' % output.content_length, replace_all=True)
    if compressible or output.content_length is None:
        # Length unknown (generated or compressed on the fly): stream chunked
        outheaders.set('Transfer-Encoding', 'chunked', replace_all=True)

    if ranges:
        if len(ranges) == 1:
            r = ranges[0]
            outheaders.set('Content-Length', '%d' % r.size, replace_all=True)
            outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True)
            output.commit = partial(output.write_ranges, r)
        else:
            range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length)
            # +4 per range for the CRLF pairs wrapping each data block
            size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges)
            outheaders.set('Content-Length', '%d' % size, replace_all=True)
            outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True)
            output.commit = partial(output.write_ranges, izip_longest(ranges, range_parts))
        status_code = httplib.PARTIAL_CONTENT
    else:
        output.commit = output.write_compressed if compressible else output.write
    return status_code, output

View File

@ -10,7 +10,7 @@ import os, ctypes, errno, socket
from io import DEFAULT_BUFFER_SIZE from io import DEFAULT_BUFFER_SIZE
from select import select from select import select
from calibre.constants import iswindows, isosx from calibre.constants import islinux, isosx
from calibre.srv.utils import eintr_retry_call from calibre.srv.utils import eintr_retry_call
def file_metadata(fileobj): def file_metadata(fileobj):
@ -33,10 +33,15 @@ def copy_range(src_file, start, size, dest):
del data del data
return total_sent return total_sent
class CannotSendfile(Exception):
pass
if iswindows: class SendfileInterrupted(Exception):
sendfile_to_socket = None pass
elif isosx:
sendfile_to_socket = sendfile_to_socket_async = None
if isosx:
libc = ctypes.CDLL(None, use_errno=True) libc = ctypes.CDLL(None, use_errno=True)
sendfile = ctypes.CFUNCTYPE( sendfile = ctypes.CFUNCTYPE(
ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int64, ctypes.POINTER(ctypes.c_int64), ctypes.c_void_p, ctypes.c_int, use_errno=True)( ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int64, ctypes.POINTER(ctypes.c_int64), ctypes.c_void_p, ctypes.c_int, use_errno=True)(
@ -68,7 +73,19 @@ elif isosx:
offset += num_bytes.value offset += num_bytes.value
return total_sent return total_sent
else: def sendfile_to_socket_async(fileobj, offset, size, socket_file):
num_bytes = ctypes.c_int64(size)
ret = sendfile(fileobj.fileno(), socket_file.fileno(), offset, ctypes.byref(num_bytes), None, 0)
if ret != 0:
err = ctypes.get_errno()
if err in (errno.EBADF, errno.ENOTSUP, errno.ENOTSOCK, errno.EOPNOTSUPP):
raise CannotSendfile()
if err in (errno.EINTR, errno.EAGAIN):
raise SendfileInterrupted()
raise IOError((err, os.strerror(err)))
return num_bytes.value
elif islinux:
libc = ctypes.CDLL(None, use_errno=True) libc = ctypes.CDLL(None, use_errno=True)
sendfile = ctypes.CFUNCTYPE( sendfile = ctypes.CFUNCTYPE(
ctypes.c_ssize_t, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t, use_errno=True)(('sendfile64', libc)) ctypes.c_ssize_t, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t, use_errno=True)(('sendfile64', libc))
@ -97,3 +114,15 @@ else:
size -= sent size -= sent
total_sent += sent total_sent += sent
return total_sent return total_sent
def sendfile_to_socket_async(fileobj, offset, size, socket_file):
off = ctypes.c_int64(offset)
sent = sendfile(socket_file.fileno(), fileobj.fileno(), ctypes.byref(off), size)
if sent < 0:
err = ctypes.get_errno()
if err in (errno.ENOSYS, errno.EINVAL):
raise CannotSendfile()
if err in (errno.EINTR, errno.EAGAIN):
raise SendfileInterrupted()
raise IOError((err, os.strerror(err)))
return sent

View File

@ -36,7 +36,7 @@ class TestServer(Thread):
Thread.__init__(self, name='ServerMain') Thread.__init__(self, name='ServerMain')
from calibre.srv.opts import Options from calibre.srv.opts import Options
from calibre.srv.loop import ServerLoop from calibre.srv.loop import ServerLoop
from calibre.srv.http import create_http_handler from calibre.srv.http_response import create_http_handler
kwargs['shutdown_timeout'] = kwargs.get('shutdown_timeout', 0.1) kwargs['shutdown_timeout'] = kwargs.get('shutdown_timeout', 0.1)
self.loop = ServerLoop( self.loop = ServerLoop(
create_http_handler(handler), create_http_handler(handler),
@ -68,5 +68,5 @@ class TestServer(Thread):
return httplib.HTTPConnection(self.address[0], self.address[1], strict=True, timeout=timeout) return httplib.HTTPConnection(self.address[0], self.address[1], strict=True, timeout=timeout)
def change_handler(self, handler): def change_handler(self, handler):
from calibre.srv.http import create_http_handler from calibre.srv.http_response import create_http_handler
self.loop.req_resp_handler = create_http_handler(handler) self.loop.handler = create_http_handler(handler)

View File

@ -6,57 +6,56 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3' __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import textwrap, httplib, hashlib, zlib, string import httplib, hashlib, zlib, string
from io import BytesIO from io import BytesIO
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from calibre import guess_type from calibre import guess_type
from calibre.srv.tests.base import BaseTest, TestServer from calibre.srv.tests.base import BaseTest, TestServer
def headers(raw):
return BytesIO(textwrap.dedent(raw).encode('utf-8'))
class TestHTTP(BaseTest): class TestHTTP(BaseTest):
def test_header_parsing(self): # {{{ def test_header_parsing(self): # {{{
'Test parsing of HTTP headers' 'Test parsing of HTTP headers'
from calibre.srv.http import read_headers from calibre.srv.http_request import HTTPHeaderParser
def test(name, raw, **kwargs): def test(name, *lines, **kwargs):
hdict = read_headers(headers(raw).readline) p = HTTPHeaderParser()
self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed') p.push(*lines)
self.assertTrue(p.finished)
self.assertSetEqual(set(p.hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')
test('Continuation line parsing', test('Continuation line parsing',
'''\ 'a: one',
a: one\r 'b: two',
b: two\r ' 2',
2\r '\t3',
\t3\r 'c:three',
c:three\r '\r\n', a='one', b='two 2 3', c='three')
\r\n''', a='one', b='two 2 3', c='three')
test('Non-ascii headers parsing', test('Non-ascii headers parsing',
'''\ b'a:mūs\r', '\r\n', a='mūs'.encode('utf-8'))
a:mūs\r
\r\n''', a='mūs'.encode('utf-8'))
test('Comma-separated parsing', test('Comma-separated parsing',
'''\ 'Accept-Encoding: one',
Accept-Encoding: one\r 'accept-Encoding: two',
Accept-Encoding: two\r '\r\n', accept_encoding='one, two')
\r\n''', accept_encoding='one, two')
def parse(line):
HTTPHeaderParser()(line)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
read_headers(headers('Connection:mūs\r\n').readline) parse('Connection:mūs\r\n')
read_headers(headers('Connection\r\n').readline) parse('Connection\r\n')
read_headers(headers('Connection:a\r\n').readline) parse('Connection:a\r\n')
read_headers(headers('Connection:a\n').readline) parse('Connection:a\n')
read_headers(headers(' Connection:a\n').readline) parse(' Connection:a\n')
parse(':a\n')
# }}} # }}}
def test_accept_encoding(self): # {{{ def test_accept_encoding(self): # {{{
'Test parsing of Accept-Encoding' 'Test parsing of Accept-Encoding'
from calibre.srv.respond import acceptable_encoding from calibre.srv.http_response import acceptable_encoding
def test(name, val, ans, allowed={'gzip'}): def test(name, val, ans, allowed={'gzip'}):
self.ae(acceptable_encoding(val, allowed), ans, name + ' failed') self.ae(acceptable_encoding(val, allowed), ans, name + ' failed')
test('Empty field', '', None) test('Empty field', '', None)
@ -68,7 +67,7 @@ class TestHTTP(BaseTest):
def test_range_parsing(self): # {{{ def test_range_parsing(self): # {{{
'Test parsing of Range header' 'Test parsing of Range header'
from calibre.srv.respond import get_ranges from calibre.srv.http_response import get_ranges
def test(val, *args): def test(val, *args):
pval = get_ranges(val, 100) pval = get_ranges(val, 100)
if len(args) == 1 and args[0] is None: if len(args) == 1 and args[0] is None:
@ -91,11 +90,38 @@ class TestHTTP(BaseTest):
'Test basic HTTP protocol conformance' 'Test basic HTTP protocol conformance'
from calibre.srv.errors import HTTP404 from calibre.srv.errors import HTTP404
body = 'Requested resource not found' body = 'Requested resource not found'
def handler(conn): def handler(data):
raise HTTP404(body) raise HTTP404(body)
def raw_send(conn, raw):
conn.send(raw)
conn._HTTPConnection__state = httplib._CS_REQ_SENT
return conn.getresponse()
with TestServer(handler, timeout=0.1, max_header_line_size=100./1024, max_request_body_size=100./(1024*1024)) as server: with TestServer(handler, timeout=0.1, max_header_line_size=100./1024, max_request_body_size=100./(1024*1024)) as server:
# Test 404
conn = server.connect() conn = server.connect()
r = raw_send(conn, b'hello\n')
self.ae(r.status, httplib.BAD_REQUEST)
self.ae(r.read(), b'HTTP requires CRLF line terminators')
r = raw_send(conn, b'\r\nGET /index.html HTTP/1.1\r\n\r\n')
self.ae(r.status, httplib.NOT_FOUND), self.ae(r.read(), b'Requested resource not found')
r = raw_send(conn, b'\r\n\r\nGET /index.html HTTP/1.1\r\n\r\n')
self.ae(r.status, httplib.BAD_REQUEST)
self.ae(r.read(), b'Multiple leading empty lines not allowed')
r = raw_send(conn, b'hello world\r\n')
self.ae(r.status, httplib.BAD_REQUEST)
self.ae(r.read(), b'Malformed Request-Line')
r = raw_send(conn, b'x' * 200)
self.ae(r.status, httplib.BAD_REQUEST)
self.ae(r.read(), b'')
r = raw_send(conn, b'XXX /index.html HTTP/1.1\r\n\r\n')
self.ae(r.status, httplib.BAD_REQUEST), self.ae(r.read(), b'Unknown HTTP method')
# Test 404
conn.request('HEAD', '/moose') conn.request('HEAD', '/moose')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.NOT_FOUND) self.ae(r.status, httplib.NOT_FOUND)
@ -104,32 +130,48 @@ class TestHTTP(BaseTest):
self.ae(r.getheader('Content-Type'), 'text/plain; charset=UTF-8') self.ae(r.getheader('Content-Type'), 'text/plain; charset=UTF-8')
self.ae(len(r.getheaders()), 3) self.ae(len(r.getheaders()), 3)
self.ae(r.read(), '') self.ae(r.read(), '')
conn.request('GET', '/moose') conn.request('GET', '/choose')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.NOT_FOUND) self.ae(r.status, httplib.NOT_FOUND)
self.ae(r.read(), 'Requested resource not found') self.ae(r.read(), b'Requested resource not found')
server.change_handler(lambda conn:conn.path[0] + conn.input_reader.read().decode('ascii')) # Test 500
orig = server.loop.log.filter_level
server.loop.log.filter_level = server.loop.log.ERROR + 10
server.change_handler(lambda data:1/0)
conn = server.connect()
conn.request('GET', '/test/')
r = conn.getresponse()
self.ae(r.status, httplib.INTERNAL_SERVER_ERROR)
server.loop.log.filter_level = orig
server.change_handler(lambda data:data.path[0] + data.read().decode('ascii'))
conn = server.connect() conn = server.connect()
# Test simple GET # Test simple GET
conn.request('GET', '/test/') conn.request('GET', '/test/')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.OK) self.ae(r.status, httplib.OK)
self.ae(r.read(), 'test') self.ae(r.read(), b'test')
# Test TRACE
lines = ['TRACE /xxx HTTP/1.1', 'Test: value', 'Xyz: abc, def', '', '']
r = raw_send(conn, ('\r\n'.join(lines)).encode('ascii'))
self.ae(r.status, httplib.OK)
self.ae(r.read().decode('utf-8'), '\n'.join(lines[:-2]))
# Test POST with simple body # Test POST with simple body
conn.request('POST', '/test', 'body') conn.request('POST', '/test', 'body')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.CREATED) self.ae(r.status, httplib.CREATED)
self.ae(r.read(), 'testbody') self.ae(r.read(), b'testbody')
# Test POST with chunked transfer encoding # Test POST with chunked transfer encoding
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'}) conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
conn.send(b'4\r\nbody\r\n0\r\n\r\n') conn.send(b'4\r\nbody\r\na\r\n1234567890\r\n0\r\n\r\n')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.CREATED) self.ae(r.status, httplib.CREATED)
self.ae(r.read(), 'testbody') self.ae(r.read(), b'testbody1234567890')
# Test various incorrect input # Test various incorrect input
orig_level, server.log.filter_level = server.log.filter_level, server.log.ERROR orig_level, server.log.filter_level = server.log.filter_level, server.log.ERROR
@ -150,19 +192,26 @@ class TestHTTP(BaseTest):
self.ae(r.status, httplib.BAD_REQUEST) self.ae(r.status, httplib.BAD_REQUEST)
self.assertIn(b'not a valid chunk size', r.read()) self.assertIn(b'not a valid chunk size', r.read())
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
conn.send(b'4\r\nbody\r\n200\r\n\r\n')
r = conn.getresponse()
self.ae(r.status, httplib.REQUEST_ENTITY_TOO_LARGE)
conn.request('POST', '/test', body='a'*200)
r = conn.getresponse()
self.ae(r.status, httplib.REQUEST_ENTITY_TOO_LARGE)
conn = server.connect() conn = server.connect()
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'}) conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
conn.send(b'3\r\nbody\r\n0\r\n\r\n') conn.send(b'3\r\nbody\r\n0\r\n\r\n')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.BAD_REQUEST) self.ae(r.status, httplib.BAD_REQUEST), self.ae(r.read(), b'Chunk does not have trailing CRLF')
self.assertIn(b'!= CRLF', r.read())
conn = server.connect(timeout=1) conn = server.connect(timeout=1)
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'}) conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
conn.send(b'30\r\nbody\r\n0\r\n\r\n') conn.send(b'30\r\nbody\r\n0\r\n\r\n')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.BAD_REQUEST) self.ae(r.status, httplib.REQUEST_TIMEOUT)
self.assertIn(b'Timed out waiting for chunk', r.read()) self.assertIn(b'', r.read())
server.log.filter_level = orig_level server.log.filter_level = orig_level
conn = server.connect() conn = server.connect()
@ -180,18 +229,17 @@ class TestHTTP(BaseTest):
# Test closing # Test closing
conn.request('GET', '/close', headers={'Connection':'close'}) conn.request('GET', '/close', headers={'Connection':'close'})
self.ae(server.loop.requests.busy, 1)
r = conn.getresponse() r = conn.getresponse()
self.ae(server.loop.num_active_connections, 1)
self.ae(r.status, 200), self.ae(r.read(), 'close') self.ae(r.status, 200), self.ae(r.read(), 'close')
self.ae(server.loop.requests.busy, 0) server.loop.wakeup()
self.ae(server.loop.num_active_connections, 0)
self.assertIsNone(conn.sock) self.assertIsNone(conn.sock)
self.ae(server.loop.requests.idle, 10)
# }}} # }}}
def test_http_response(self): # {{{ def test_http_response(self): # {{{
'Test HTTP protocol responses' 'Test HTTP protocol responses'
from calibre.srv.respond import parse_multipart_byterange from calibre.srv.http_response import parse_multipart_byterange
def handler(conn): def handler(conn):
return conn.generate_static_output('test', lambda : ''.join(conn.path)) return conn.generate_static_output('test', lambda : ''.join(conn.path))
with TestServer(handler, timeout=0.1, compress_min_size=0) as server, \ with TestServer(handler, timeout=0.1, compress_min_size=0) as server, \
@ -216,9 +264,10 @@ class TestHTTP(BaseTest):
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, httplib.OK), self.ae(zlib.decompress(r.read(), 16+zlib.MAX_WBITS), b'an_etagged_path') self.ae(r.status, httplib.OK), self.ae(zlib.decompress(r.read(), 16+zlib.MAX_WBITS), b'an_etagged_path')
for i in '12': # Test getting a filesystem file
# Test getting a filesystem file for use_sendfile in (True, False):
server.change_handler(lambda conn: f) server.change_handler(lambda conn: f)
server.loop.opts.use_sendfile = use_sendfile
conn = server.connect() conn = server.connect()
conn.request('GET', '/test') conn.request('GET', '/test')
r = conn.getresponse() r = conn.getresponse()
@ -229,27 +278,27 @@ class TestHTTP(BaseTest):
self.ae(int(r.getheader('Content-Length')), len(fdata)) self.ae(int(r.getheader('Content-Length')), len(fdata))
self.ae(r.status, httplib.OK), self.ae(r.read(), fdata) self.ae(r.status, httplib.OK), self.ae(r.read(), fdata)
conn.request('GET', '/test', headers={'Range':'bytes=0-25'}) conn.request('GET', '/test', headers={'Range':'bytes=2-25'})
r = conn.getresponse() r = conn.getresponse()
self.ae(type('')(r.getheader('Accept-Ranges')), 'bytes') self.ae(type('')(r.getheader('Accept-Ranges')), 'bytes')
self.ae(type('')(r.getheader('Content-Range')), 'bytes 0-25/%d' % len(fdata)) self.ae(type('')(r.getheader('Content-Range')), 'bytes 2-25/%d' % len(fdata))
self.ae(int(r.getheader('Content-Length')), 26) self.ae(int(r.getheader('Content-Length')), 24)
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[0:26]) self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[2:26])
conn.request('GET', '/test', headers={'Range':'bytes=100000-'}) conn.request('GET', '/test', headers={'Range':'bytes=100000-'})
r = conn.getresponse() r = conn.getresponse()
self.ae(type('')(r.getheader('Content-Range')), 'bytes */%d' % len(fdata)) self.ae(type('')(r.getheader('Content-Range')), 'bytes */%d' % len(fdata))
self.ae(r.status, httplib.REQUESTED_RANGE_NOT_SATISFIABLE) self.ae(r.status, httplib.REQUESTED_RANGE_NOT_SATISFIABLE)
conn.request('GET', '/test', headers={'Range':'bytes=0-1000000'})
r = conn.getresponse()
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata)
conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':etag}) conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':etag})
r = conn.getresponse() r = conn.getresponse()
self.ae(int(r.getheader('Content-Length')), 26) self.ae(int(r.getheader('Content-Length')), 26)
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[25:51]) self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[25:51])
conn.request('GET', '/test', headers={'Range':'bytes=0-1000000'})
r = conn.getresponse()
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata)
conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':'"nomatch"'}) conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':'"nomatch"'})
r = conn.getresponse() r = conn.getresponse()
self.assertFalse(r.getheader('Content-Range')) self.assertFalse(r.getheader('Content-Range'))
@ -274,6 +323,4 @@ class TestHTTP(BaseTest):
self.ae(data, r.read()) self.ae(data, r.read())
# Now try it without sendfile # Now try it without sendfile
server.loop.opts.use_sendfile ^= True
conn = server.connect()
# }}} # }}}

View File

@ -11,9 +11,13 @@ from contextlib import closing
from urlparse import parse_qs from urlparse import parse_qs
import repr as reprlib import repr as reprlib
from email.utils import formatdate from email.utils import formatdate
from operator import itemgetter
from calibre.constants import iswindows from calibre.constants import iswindows
HTTP1 = 'HTTP/1.0'
HTTP11 = 'HTTP/1.1'
def http_date(timeval=None): def http_date(timeval=None):
return type('')(formatdate(timeval=timeval, usegmt=True)) return type('')(formatdate(timeval=timeval, usegmt=True))
@ -85,7 +89,8 @@ class MultiDict(dict): # {{{
__str__ = __unicode__ = __repr__ __str__ = __unicode__ = __repr__
def pretty(self, leading_whitespace=''): def pretty(self, leading_whitespace=''):
return leading_whitespace + ('\n' + leading_whitespace).join('%s: %s' % (k, v) for k, v in self.items()) return leading_whitespace + ('\n' + leading_whitespace).join(
'%s: %s' % (k, (repr(v) if isinstance(v, bytes) else v)) for k, v in sorted(self.items(), key=itemgetter(0)))
# }}} # }}}
def error_codes(*errnames): def error_codes(*errnames):
@ -112,29 +117,15 @@ socket_errors_socket_closed = error_codes( # errors indicating a disconnected c
socket_errors_nonblocking = error_codes( socket_errors_nonblocking = error_codes(
'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
class Corked(object): def start_cork(sock):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
if hasattr(socket, 'TCP_CORK'):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1)
' Context manager to turn on TCP corking. Ensures maximum throughput for large logical packets. ' def stop_cork(sock):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
def __init__(self, sock): if hasattr(socket, 'TCP_CORK'):
self.sock = sock sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)
def __enter__(self):
nodelay = self.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
if nodelay == 1:
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
self.set_nodelay = True
else:
self.set_nodelay = False
if hasattr(socket, 'TCP_CORK'):
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1)
def __exit__(self, *args):
if self.set_nodelay:
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if hasattr(socket, 'TCP_CORK'):
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)
self.sock.send(b'') # Ensure that uncorking occurs
def create_sock_pair(port=0): def create_sock_pair(port=0):
'''Create socket pair. Works also on windows by using an ephemeral TCP port.''' '''Create socket pair. Works also on windows by using an ephemeral TCP port.'''