mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Finish implementation of async http server
This commit is contained in:
parent
2068e52b82
commit
d075eff758
@ -1,337 +0,0 @@
|
||||
#!/usr/bin/env python2
|
||||
# vim:fileencoding=utf-8
|
||||
from __future__ import (unicode_literals, division, absolute_import,
|
||||
print_function)
|
||||
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import ssl, socket, select, os
|
||||
from io import BytesIO
|
||||
|
||||
from calibre import as_unicode
|
||||
from calibre.srv.opts import Options
|
||||
from calibre.srv.utils import (
|
||||
socket_errors_socket_closed, socket_errors_nonblocking, HandleInterrupt)
|
||||
from calibre.utils.socket_inheritance import set_socket_inherit
|
||||
from calibre.utils.logging import ThreadSafeLog
|
||||
from calibre.utils.monotonic import monotonic
|
||||
|
||||
READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR'
|
||||
|
||||
class Connection(object):
    """A single client connection managed by the non-blocking ServerLoop.

    The connection is a small state machine: ``wait_for`` tells the loop
    which select() events to wait for (READ/WRITE/RDWR) and ``handle_event``
    is the callback to run when such an event arrives.  Subclasses implement
    connection_ready() to start their protocol.
    """

    def __init__(self, socket, opts, ssl_context):
        self.opts = opts
        self.ssl_context = ssl_context
        self.wait_for = READ
        if self.ssl_context is not None:
            # TLS connections are not ready until the (asynchronous)
            # handshake has completed
            self.ready = False
            self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False)
            self.set_state(RDWR, self.do_ssl_handshake)
        else:
            self.ready = True
            self.socket = socket
            self.connection_ready()
        self.last_activity = monotonic()

    def set_state(self, wait_for, func):
        # Transition the state machine: wait for the given select() events
        # and dispatch them to func
        self.wait_for = wait_for
        self.handle_event = func

    def do_ssl_handshake(self, event):
        try:
            # FIX: was self._sslobj.do_handshake() — no _sslobj attribute is
            # ever assigned on this class; the wrapped SSL socket created in
            # __init__ drives the handshake.
            self.socket.do_handshake()
        except ssl.SSLWantReadError:
            self.set_state(READ, self.do_ssl_handshake)
            return  # FIX: handshake incomplete, do not mark ready yet
        except ssl.SSLWantWriteError:
            self.set_state(WRITE, self.do_ssl_handshake)
            return  # FIX: handshake incomplete, do not mark ready yet
        self.ready = True
        self.connection_ready()

    def send(self, data):
        # Non-blocking send; returns the number of bytes written (0 when the
        # socket would block or has been closed by the peer)
        try:
            ret = self.socket.send(data)
            self.last_activity = monotonic()
            return ret
        except socket.error as e:
            if e.errno in socket_errors_nonblocking:
                return 0
            elif e.errno in socket_errors_socket_closed:
                self.ready = False
                return 0
            raise

    def recv(self, buffer_size):
        # Non-blocking receive; b'' signals a closed connection
        try:
            data = self.socket.recv(buffer_size)
            self.last_activity = monotonic()
            if not data:
                # a closed connection is indicated by signaling
                # a read condition, and having recv() return 0.
                self.ready = False
                return b''
            return data
        except socket.error as e:
            if e.errno in socket_errors_socket_closed:
                self.ready = False
                return b''
            # FIX: unexpected socket errors were silently swallowed (the
            # function fell through and returned None); re-raise so the
            # server loop can log and terminate the connection, matching
            # the behavior of send()
            raise

    def close(self):
        self.ready = False
        try:
            self.socket.shutdown(socket.SHUT_WR)
            self.socket.close()
        except socket.error:
            pass

    def connection_ready(self):
        # Subclasses start their protocol here (called once the connection,
        # including any TLS handshake, is established)
        raise NotImplementedError()
|
||||
|
||||
class ServerLoop(object):
    """select()-based server loop that multiplexes many Connection objects
    over a single listening socket."""

    def __init__(
        self,
        handler,
        bind_address=('localhost', 8080),
        opts=None,
        # A calibre logging object. If None, a default log that logs to
        # stdout is used
        log=None
    ):
        self.ready = False
        self.handler = handler  # Connection subclass instantiated per accepted socket
        self.opts = opts or Options()
        self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG)

        ba = tuple(bind_address)
        if not ba[0]:
            # AI_PASSIVE does not work with host of '' or None
            ba = ('0.0.0.0', ba[1])
        self.bind_address = ba
        self.bound_address = None

        self.ssl_context = None
        if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None:
            self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            self.ssl_context.load_cert_chain(certfile=self.opts.ssl_certfile, keyfile=self.opts.ssl_keyfile)

        # Support systemd-style socket activation, if configured
        self.pre_activated_socket = None
        if self.opts.allow_socket_preallocation:
            from calibre.srv.pre_activated import pre_activated_socket
            self.pre_activated_socket = pre_activated_socket()
            if self.pre_activated_socket is not None:
                set_socket_inherit(self.pre_activated_socket, False)
                self.bind_address = self.pre_activated_socket.getsockname()

    def __str__(self):
        return "%s(%r)" % (self.__class__.__name__, self.bind_address)
    __repr__ = __str__

    def serve_forever(self):
        """ Listen for incoming connections. """

        if self.pre_activated_socket is None:
            # AF_INET or AF_INET6 socket
            # Get the correct address family for our host (allows IPv6
            # addresses)
            host, port = self.bind_address
            try:
                info = socket.getaddrinfo(
                    host, port, socket.AF_UNSPEC,
                    socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
            except socket.gaierror:
                if ':' in host:
                    info = [(socket.AF_INET6, socket.SOCK_STREAM,
                             0, "", self.bind_address + (0, 0))]
                else:
                    info = [(socket.AF_INET, socket.SOCK_STREAM,
                             0, "", self.bind_address)]

            self.socket = None
            msg = "No socket could be created"
            for res in info:
                af, socktype, proto, canonname, sa = res
                try:
                    self.bind(af, socktype, proto)
                # FIX: was the Python-2-only "except socket.error, serr:"
                # comma syntax; use "as", consistent with every other
                # exception handler in this file
                except socket.error as serr:
                    msg = "%s -- (%s: %s)" % (msg, sa, serr)
                    if self.socket:
                        self.socket.close()
                    self.socket = None
                    continue
                break
            if not self.socket:
                raise socket.error(msg)
        else:
            self.socket = self.pre_activated_socket
            self.pre_activated_socket = None
            self.setup_socket()

        self.ready = True
        self.connection_map = {}
        self.socket.listen(min(socket.SOMAXCONN, 128))
        self.bound_address = ba = self.socket.getsockname()
        if isinstance(ba, tuple):
            ba = ':'.join(map(type(''), ba))
        self.log('calibre server listening on', ba)

        while True:
            try:
                self.tick()
            except (KeyboardInterrupt, SystemExit):
                self.shutdown()
                break
            except:
                # Keep serving even if a single tick fails
                self.log.exception('Error in ServerLoop.tick')

    def setup_socket(self):
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if self.opts.no_delay and not isinstance(self.bind_address, basestring):
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
        # activate dual-stack.
        if (hasattr(socket, 'AF_INET6') and self.socket.family == socket.AF_INET6 and
                self.bind_address[0] in ('::', '::0', '::0.0.0.0')):
            try:
                self.socket.setsockopt(
                    socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            except (AttributeError, socket.error):
                # Apparently, the socket option is not available in
                # this machine's TCP stack
                pass
        self.socket.setblocking(0)

    def bind(self, family, atype, proto=0):
        '''Create (or recreate) the actual socket object.'''
        self.socket = socket.socket(family, atype, proto)
        set_socket_inherit(self.socket, False)
        self.setup_socket()
        self.socket.bind(self.bind_address)

    def tick(self):
        # One iteration of the event loop: reap idle connections, select()
        # on all sockets and dispatch events to the connections
        now = monotonic()
        for s, conn in tuple(self.connection_map.iteritems()):
            if now - conn.last_activity > self.opts.timeout:
                self.log.debug('Closing connection because of extended inactivity')
                self.close(s, conn)

        read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR]
        write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR]
        readable, writable = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout)[:2]
        if not self.ready:
            return

        for s, conn, event in self.get_actions(readable, writable):
            try:
                conn.handle_event(event)
                if not conn.ready:
                    self.close(s, conn)
            except Exception as e:
                if conn.ready:
                    self.log.exception('Unhandled exception, terminating connection')
                else:
                    self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e))
                self.close(s, conn)

    def wakeup(self):
        # Touch our own socket to make select() return immediately.
        sock = getattr(self, "socket", None)
        if sock is not None:
            try:
                host, port = sock.getsockname()[:2]
            except socket.error as e:
                if e.errno not in socket_errors_socket_closed:
                    raise
            else:
                for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                              socket.SOCK_STREAM):
                    af, socktype, proto, canonname, sa = res
                    s = None
                    try:
                        s = socket.socket(af, socktype, proto)
                        s.settimeout(1.0)
                        s.connect((host, port))
                        s.close()
                    except socket.error:
                        if s is not None:
                            s.close()
        return sock

    def close(self, s, conn):
        self.connection_map.pop(s, None)
        conn.close()

    def get_actions(self, readable, writable):
        # Translate select() results into (socket, connection, event) triples
        for s in readable:
            if s is self.socket:
                # Event on the listening socket: accept a new connection
                s, addr = self.accept()
                if s is not None:
                    self.connection_map[s] = conn = self.handler(s, self.opts, self.ssl_context)
                    if self.ssl_context is not None:
                        yield s, conn, RDWR
            else:
                yield s, self.connection_map[s], READ
        for s in writable:
            yield s, self.connection_map[s], WRITE

    def accept(self):
        try:
            return self.socket.accept()
        except socket.error:
            return None, None

    def stop(self):
        self.ready = False
        self.wakeup()

    def shutdown(self):
        try:
            if getattr(self, 'socket', None):
                self.socket.close()
                self.socket = None
        except socket.error:
            pass
        for s, conn in tuple(self.connection_map.iteritems()):
            self.close(s, conn)
|
||||
|
||||
class EchoLine(Connection):  # {{{
    """Trivial line-echo protocol used to exercise the async server: every
    received line is echoed back; a blank line gets 'bye' prepended and the
    connection is closed after the echo."""

    bye_after_echo = False

    def connection_ready(self):
        # Start with an empty buffer, waiting for incoming bytes
        self.rbuf = BytesIO()
        self.set_state(READ, self.read_line)

    def read_line(self, event):
        byte = self.recv(1)
        if not byte:
            return
        self.rbuf.write(byte)
        if byte != b'\n':
            return
        # A complete line has arrived
        if self.rbuf.tell() < 3:
            # Blank line: answer with 'bye' and disconnect afterwards
            self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue())
            self.bye_after_echo = True
        self.set_state(WRITE, self.echo)
        self.rbuf.seek(0)

    def echo(self, event):
        start = self.rbuf.tell()
        self.rbuf.seek(0, os.SEEK_END)
        unsent = self.rbuf.tell() - start
        self.rbuf.seek(start)
        nbytes = self.send(self.rbuf.read(512))
        if nbytes == unsent:
            # Everything flushed; go back to reading the next line
            self.rbuf = BytesIO()
            self.set_state(READ, self.read_line)
            if self.bye_after_echo:
                self.ready = False
        else:
            # Partial write: resume from where we stopped on the next event
            self.rbuf.seek(start + nbytes)
    # }}}
|
||||
|
||||
if __name__ == '__main__':
    # Ad-hoc manual test: run the echo server until interrupted; wakeup()
    # is used by the interrupt handler to break out of the select() loop
    s = ServerLoop(EchoLine)
    with HandleInterrupt(s.wakeup):
        s.serve_forever()
|
@ -7,26 +7,5 @@ __license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
|
||||
class MaxSizeExceeded(Exception):
    """Raised when some part of a request is larger than its configured
    limit; carries the offending size and the limit."""

    def __init__(self, prefix, size, limit):
        Exception.__init__(self, '%s %d > maximum %d' % (prefix, size, limit))
        self.size = size
        self.limit = limit
|
||||
|
||||
class HTTP404(Exception):
    """Raised by request handlers to generate a 404 Not Found response."""
    pass
|
||||
|
||||
class IfNoneMatch(Exception):
    """Signals an If-None-Match condition; carries the ETag involved, if
    any. The caller decides what response to generate."""

    def __init__(self, etag=None):
        super(IfNoneMatch, self).__init__('')
        self.etag = etag
|
||||
|
||||
class BadChunkedInput(ValueError):
    """Raised when a chunked transfer-encoded request body is malformed."""
    pass
|
||||
|
||||
class RangeNotSatisfiable(ValueError):
    """Signals that a requested byte Range cannot be satisfied; carries the
    total content length for the Content-Range response header."""

    def __init__(self, content_length):
        super(RangeNotSatisfiable, self).__init__()
        self.content_length = content_length
|
||||
|
@ -1,602 +0,0 @@
|
||||
#!/usr/bin/env python2
|
||||
# vim:fileencoding=utf-8
|
||||
from __future__ import (unicode_literals, division, absolute_import,
|
||||
print_function)
|
||||
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import httplib, socket, re, os
|
||||
from io import BytesIO
|
||||
import repr as reprlib
|
||||
from urllib import unquote
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
|
||||
from calibre import as_unicode
|
||||
from calibre.constants import __version__
|
||||
from calibre.srv.errors import (
|
||||
MaxSizeExceeded, HTTP404, IfNoneMatch, BadChunkedInput, RangeNotSatisfiable)
|
||||
from calibre.srv.respond import finalize_output, generate_static_output
|
||||
from calibre.srv.utils import MultiDict, http_date, socket_errors_to_ignore
|
||||
|
||||
# Protocol version strings. These exact objects are compared by identity
# ("is HTTP11") elsewhere in this file, so always use these constants.
HTTP1 = 'HTTP/1.0'
HTTP11 = 'HTTP/1.1'
protocol_map = {(1, 0):HTTP1, (1, 1):HTTP11}
# Matches a percent-encoded forward slash; such a slash is data, not a path
# separator, and must survive path splitting (see read_request_line)
quoted_slash = re.compile(br'%2[fF]')
|
||||
|
||||
def parse_request_uri(uri):  # {{{
    """Parse a Request-URI into (scheme, authority, path).

    Note that Request-URI's must be one of::

        Request-URI = "*" | absoluteURI | abs_path | authority

    Therefore, a Request-URI which starts with a double forward-slash
    cannot be a "net_path"::

        net_path = "//" authority [ abs_path ]

    Instead, it must be interpreted as an "abs_path" with an empty first
    path segment::

        abs_path = "/" path_segments
        path_segments = segment *( "/" segment )
        segment = *pchar *( ";" param )
        param = *pchar
    """
    if uri == b'*':
        return None, None, uri

    i = uri.find(b'://')
    if i > 0 and b'?' not in uri[:i]:
        # An absoluteURI.
        # If there's a scheme (and it must be http or https), then:
        # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query
        # ]]
        scheme, remainder = uri[:i].lower(), uri[i + 3:]
        # FIX: an absoluteURI need not contain an abs_path at all
        # (e.g. b'http://example.com'); the previous
        # "authority, path = remainder.split(b'/', 1)" raised ValueError
        # on the tuple unpack in that case. partition() yields an empty
        # path, normalized to b'/' below.
        authority, _, path = remainder.partition(b'/')
        path = b'/' + path
        return scheme, authority, path

    if uri.startswith(b'/'):
        # An abs_path.
        return None, None, uri
    else:
        # An authority.
        return None, uri, None
# }}}
|
||||
|
||||
# Headers that may legally appear multiple times in a request and are
# folded into a single comma separated value (RFC 2616 sec 4.2)
comma_separated_headers = {
    'Accept', 'Accept-Charset', 'Accept-Encoding',
    'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
    'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
    'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
    'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
    'WWW-Authenticate'
}

# Headers whose values this server interprets and which must therefore
# decode cleanly as ASCII (see safe_decode in read_headers)
decoded_headers = {
    'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
} | comma_separated_headers
|
||||
|
||||
def read_headers(readline):  # {{{
    """
    Read headers from the given stream into the given header dict.

    If hdict is None, a new header dict is created. Returns the populated
    header dict.

    Headers which are repeated are folded together using a comma if their
    specification so dictates.

    This function raises ValueError when the read bytes violate the HTTP spec.
    You should probably return "400 Bad Request" if this happens.
    """
    hdict = MultiDict()

    def safe_decode(hname, value):
        # Header values the server must interpret (decoded_headers) have to
        # be ASCII; for all other headers the raw byte string is kept when
        # decoding fails
        try:
            return value.decode('ascii')
        except UnicodeDecodeError:
            if hname in decoded_headers:
                raise
            return value

    # Header currently being accumulated (a value may span continuation lines)
    current_key = current_value = None

    def commit():
        # Flush the accumulated header (if any) into hdict, folding repeated
        # comma-separated headers into one value
        if current_key:
            key = current_key.decode('ascii')
            val = safe_decode(key, current_value)
            if key in comma_separated_headers:
                existing = hdict.pop(key)
                if existing is not None:
                    val = existing + ', ' + val
            hdict[key] = val

    while True:
        line = readline()
        if not line:
            # No more data--illegal end of headers
            raise ValueError("Illegal end of headers.")

        if line == b'\r\n':
            # Normal end of headers
            commit()
            break
        if not line.endswith(b'\r\n'):
            raise ValueError("HTTP requires CRLF terminators")

        if line[0] in b' \t':
            # It's a continuation line.
            if current_key is None or current_value is None:
                raise ValueError('Orphaned continuation line')
            current_value += b' ' + line.strip()
        else:
            # New header: first commit the previous one, then start
            # accumulating this one (keys are normalized to Title-Case)
            commit()
            current_key = current_value = None
            k, v = line.split(b':', 1)
            current_key = k.strip().title()
            current_value = v.strip()

    return hdict
# }}}
|
||||
|
||||
def http_communicate(handle_request, conn):
    ' Represents interaction with a http client over a single, persistent connection '
    request_seen = False

    def repr_for_pair(pair):
        # Logging helper that tolerates pair being None or not yet parsed
        return pair.repr_for_log() if getattr(pair, 'started_request', False) else 'None'

    def simple_response(pair, code, msg=''):
        # Best-effort error response; ignore socket errors caused by the
        # client having already gone away
        if pair and not pair.sent_headers:
            try:
                pair.simple_response(code, msg=msg)
            except socket.error as e:
                if e.errno not in socket_errors_to_ignore:
                    raise

    try:
        while True:
            # (re)set pair to None so that if something goes wrong in
            # the HTTPPair constructor, the error doesn't
            # get written to the previous request.
            pair = None
            pair = HTTPPair(handle_request, conn)

            # This order of operations should guarantee correct pipelining.
            pair.parse_request()
            if not pair.ready:
                # Something went wrong in the parsing (and the server has
                # probably already made a simple_response). Return and
                # let the conn close.
                return

            request_seen = True
            pair.respond()
            if pair.close_connection:
                return
    except socket.timeout:
        # Don't error if we're between requests; only error
        # if 1) no request has been started at all, or 2) we're
        # in the middle of a request. This allows persistent
        # connections for HTTP/1.1
        if (not request_seen) or (pair and pair.started_request):
            # Don't bother writing the 408 if the response
            # has already started being written.
            simple_response(pair, httplib.REQUEST_TIMEOUT)
    except socket.error:
        # This socket is broken. Log the error and close connection
        conn.server_loop.log.exception(
            'Communication failed (socket error) while processing request:', repr_for_pair(pair))
    except MaxSizeExceeded as e:
        conn.server_loop.log.warn('Too large request body (%d > %d) for request:' % (e.size, e.limit), repr_for_pair(pair))
        # Can happen if the request uses chunked transfer encoding
        simple_response(pair, httplib.REQUEST_ENTITY_TOO_LARGE,
                "The entity sent with the request exceeds the maximum "
                "allowed bytes (%d)." % pair.max_request_body_size)
    except BadChunkedInput as e:
        conn.server_loop.log.warn('Bad chunked encoding (%s) for request:' % as_unicode(e.message), repr_for_pair(pair))
        simple_response(pair, httplib.BAD_REQUEST,
                'Invalid chunked encoding for request body: %s' % as_unicode(e.message))
    except Exception:
        # Last-resort handler: log with full traceback and answer 500
        conn.server_loop.log.exception('Error serving request:', pair.repr_for_log() if getattr(pair, 'started_request', False) else 'None')
        simple_response(pair, httplib.INTERNAL_SERVER_ERROR)
|
||||
|
||||
class FixedSizeReader(object):  # {{{
    """Read at most ``content_length`` bytes of request body from a
    file-like socket object; further reads return b''."""

    def __init__(self, socket_file, content_length):
        self.socket_file, self.remaining = socket_file, content_length

    def read(self, size=-1):
        # A negative size means: read everything that is left
        limit = self.remaining if size < 0 else min(size, self.remaining)
        if limit < 1:
            return b''
        chunk = self.socket_file.read(limit)
        self.remaining -= len(chunk)
        return chunk
    # }}}
|
||||
|
||||
class ChunkedReader(object):  # {{{
    """Read a request body sent with Transfer-Encoding: chunked, enforcing
    a total size limit of ``maxsize`` bytes."""

    def __init__(self, socket_file, maxsize):
        self.socket_file, self.maxsize = socket_file, maxsize
        # Decoded-but-unconsumed bytes; invariant: the buffer's file position
        # always equals the number of buffered bytes (used via tell() below)
        self.rbuf = BytesIO()
        self.bytes_read = 0
        self.finished = False

    def check_size(self):
        if self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read, self.maxsize)

    def read_chunk(self):
        # Decode one chunk from the stream into self.rbuf
        if self.finished:
            return
        line = self.socket_file.readline()
        self.bytes_read += len(line)
        self.check_size()
        chunk_size = line.strip().split(b';', 1)[0]
        try:
            # FIX: parse the size token with any chunk-extension stripped;
            # the previous int(line, 16) raised on a legal extension like
            # b'1a;name=val\r\n' and caused a spurious BadChunkedInput.
            # +2 accounts for the chunk's trailing CRLF.
            chunk_size = int(chunk_size, 16) + 2
        except Exception:
            raise BadChunkedInput('%s is not a valid chunk size' % reprlib.repr(chunk_size))
        if chunk_size + self.bytes_read > self.maxsize:
            raise MaxSizeExceeded('Request entity too large', self.bytes_read + chunk_size, self.maxsize)
        try:
            chunk = self.socket_file.read(chunk_size)
        except socket.timeout:
            raise BadChunkedInput('Timed out waiting for chunk of size %d to complete' % chunk_size)
        if len(chunk) < chunk_size:
            raise BadChunkedInput('Bad chunked encoding, chunk truncated: %d < %s' % (len(chunk), chunk_size))
        if not chunk.endswith(b'\r\n'):
            raise BadChunkedInput('Bad chunked encoding: %r != CRLF' % chunk[-2:])
        self.rbuf.seek(0, os.SEEK_END)
        self.bytes_read += chunk_size
        if chunk_size == 2:
            # The zero-length chunk ("0\r\n") terminates the body
            self.finished = True
        else:
            self.rbuf.write(chunk[:-2])

    def read(self, size=-1):
        if size < 0:
            # Read all remaining data
            while not self.finished:
                self.read_chunk()
            self.rbuf.seek(0)
            rv = self.rbuf.read()
            if rv:
                # FIX: seek(0) before truncating — truncate() does not move
                # the file position, so the old truncate(0) at a non-zero
                # position made later writes NUL-pad the buffer
                self.rbuf.seek(0)
                self.rbuf.truncate()
            return rv
        if size == 0:
            return b''
        while self.rbuf.tell() < size and not self.finished:
            self.read_chunk()
        data = self.rbuf.getvalue()
        # FIX: same seek-before-truncate issue as above
        self.rbuf.seek(0)
        self.rbuf.truncate()
        if size < len(data):
            # Keep the unconsumed tail buffered for the next call
            self.rbuf.write(data[size:])
            return data[:size]
        return data
    # }}}
|
||||
|
||||
class HTTPPair(object):
|
||||
|
||||
''' Represents a HTTP request/response pair '''
|
||||
|
||||
def __init__(self, handle_request, conn):
    # One HTTP request/response exchange on an established connection;
    # handle_request is the application callback that produces the output
    self.conn = conn
    self.server_loop = conn.server_loop
    # Configured limits; opts store KB and MB respectively, convert to bytes
    self.max_header_line_size = int(self.server_loop.opts.max_header_line_size * 1024)
    self.max_request_body_size = int(self.server_loop.opts.max_request_body_size * 1024 * 1024)
    self.scheme = 'http' if self.server_loop.ssl_context is None else 'https'
    self.inheaders = MultiDict()
    self.outheaders = MultiDict()
    self.handle_request = handle_request
    self.request_line = None
    self.path = ()
    self.qs = MultiDict()
    self.method = None

    """When True, the request has been parsed and is ready to begin generating
    the response. When False, signals the calling Connection that the response
    should not be generated and the connection should close, immediately after
    parsing the request."""
    self.ready = False

    """Signals the calling Connection that the request should close. This does
    not imply an error! The client and/or server may each request that the
    connection be closed, after the response."""
    self.close_connection = False

    self.started_request = False
    self.response_protocol = HTTP1

    self.status_code = None
    self.sent_headers = False

    self.request_content_length = 0
    self.chunked_read = False
|
||||
|
||||
def parse_request(self):
    """Parse the next HTTP request start-line and message-headers."""
    try:
        if not self.read_request_line():
            return
    except MaxSizeExceeded as e:
        self.server_loop.log.warn('Too large request URI (%d > %d), dropping connection' % (e.size, e.limit))
        self.simple_response(
            httplib.REQUEST_URI_TOO_LONG,
            "The Request-URI sent with the request exceeds the maximum allowed bytes.")
        return

    try:
        if not self.read_request_headers():
            return
    except MaxSizeExceeded as e:
        self.server_loop.log.warn('Too large header (%d > %d) for request, dropping connection' % (e.size, e.limit))
        self.simple_response(
            httplib.REQUEST_ENTITY_TOO_LARGE,
            "The headers sent with the request exceed the maximum allowed bytes.")
        return

    # Only now is this pair ready for response generation
    self.ready = True
|
||||
|
||||
def read_request_line(self):
    # Read and parse the Request-Line (e.g. b'GET /path HTTP/1.1'),
    # populating method, request/response protocol, scheme, path and query
    # string. Returns False (after sending an error response) on malformed
    # input, True on success.
    self.request_line = request_line = self.conn.socket_file.readline(maxsize=self.max_header_line_size)

    # Set started_request to True so http_communicate() knows to send 408
    # from here on out.
    self.started_request = True
    if not request_line:
        return False

    if request_line == b'\r\n':
        # RFC 2616 sec 4.1: "...if the server is reading the protocol
        # stream at the beginning of a message and receives a CRLF
        # first, it should ignore the CRLF."
        # But only ignore one leading line! else we enable a DoS.
        request_line = self.conn.socket_file.readline(maxsize=self.max_header_line_size)
        if not request_line:
            return False

    if not request_line.endswith(b'\r\n'):
        self.simple_response(
            httplib.BAD_REQUEST, "HTTP requires CRLF terminators")
        return False

    try:
        method, uri, req_protocol = request_line.strip().split(b' ', 2)
        # req_protocol is b'HTTP/x.y'; pick out the version digits
        rp = int(req_protocol[5]), int(req_protocol[7])
        self.method = method.decode('ascii')
    except (ValueError, IndexError):
        self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line")
        return False

    try:
        self.request_protocol = protocol_map[rp]
    except KeyError:
        self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED)
        return False
    # Respond with the lower of our protocol (1.1) and the client's
    self.response_protocol = protocol_map[min((1, 1), rp)]

    scheme, authority, path = parse_request_uri(uri)
    if b'#' in path:
        self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.")
        return False

    if scheme:
        try:
            self.scheme = scheme.decode('ascii')
        except ValueError:
            self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme')
            return False

    qs = b''
    if b'?' in path:
        path, qs = path.split(b'?', 1)
        try:
            self.qs = MultiDict.create_from_query_string(qs)
        except Exception:
            # FIX: the original passed 'Unparseable query string' as a third
            # positional argument to simple_response(), where it landed in
            # the read_remaining_input parameter (and, being truthy, would
            # dereference the not-yet-created input_reader). It was clearly
            # meant as message text.
            self.simple_response(httplib.BAD_REQUEST,
                    "Malformed Request-Line: Unparseable query string")
            return False

    try:
        # %2F is data, not a path separator (see quoted_slash): decode the
        # segments around it separately and re-join with the literal
        path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path))
    except ValueError as e:
        self.simple_response(httplib.BAD_REQUEST, as_unicode(e))
        return False
    self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/'))))

    return True
|
||||
|
||||
def read_request_headers(self):
    # Read the message-headers, then apply persistent-connection,
    # Transfer-Encoding and Expect: 100-continue policy. Returns False
    # (after sending an error response) on failure.
    # then all the http headers
    try:
        self.inheaders = read_headers(partial(self.conn.socket_file.readline, maxsize=self.max_header_line_size))
        self.request_content_length = int(self.inheaders.get('Content-Length', 0))
    except ValueError as e:
        self.simple_response(httplib.BAD_REQUEST, as_unicode(e))
        return False

    if self.request_content_length > self.max_request_body_size:
        self.simple_response(
            httplib.REQUEST_ENTITY_TOO_LARGE,
            "The entity sent with the request exceeds the maximum "
            "allowed bytes (%d)." % self.max_request_body_size)
        return False

    # Persistent connection support
    if self.response_protocol is HTTP11:
        # Both server and client are HTTP/1.1
        if self.inheaders.get("Connection", "") == "close":
            self.close_connection = True
    else:
        # Either the server or client (or both) are HTTP/1.0
        if self.inheaders.get("Connection", "") != "Keep-Alive":
            self.close_connection = True

    # Transfer-Encoding support
    te = ()
    if self.response_protocol is HTTP11:
        rte = self.inheaders.get("Transfer-Encoding")
        if rte:
            te = [x.strip().lower() for x in rte.split(",") if x.strip()]
    self.chunked_read = False
    if te:
        for enc in te:
            if enc == "chunked":
                self.chunked_read = True
            else:
                # Note that, even if we see "chunked", we must reject
                # if there is an extension we don't recognize.
                self.simple_response(httplib.NOT_IMPLEMENTED, "Unknown transfer encoding: %r" % enc)
                self.close_connection = True
                return False

    if self.inheaders.get("Expect", '').lower() == "100-continue":
        # Don't use simple_response here, because it emits headers
        # we don't want.
        msg = HTTP11 + " 100 Continue\r\n\r\n"
        self.flushed_write(msg.encode('ascii'))
    return True
|
||||
|
||||
def simple_response(self, status_code, msg="", read_remaining_input=False):
    # Write a complete, minimal plain-text response for status_code,
    # bypassing the normal response pipeline; used for error responses
    abort = status_code in (httplib.REQUEST_ENTITY_TOO_LARGE, httplib.REQUEST_URI_TOO_LONG)
    if abort:
        self.close_connection = True
        if self.response_protocol is HTTP1:
            # HTTP/1.0 has no 413/414 codes
            status_code = httplib.BAD_REQUEST

    msg = msg.encode('utf-8')
    buf = [
        '%s %d %s' % (self.response_protocol, status_code, httplib.responses[status_code]),
        "Content-Length: %s" % len(msg),
        "Content-Type: text/plain; charset=UTF-8",
        "Date: " + http_date(),
    ]
    if abort and self.response_protocol is HTTP11:
        buf.append("Connection: close")
    buf.append('')
    buf = [(x + '\r\n').encode('ascii') for x in buf]
    # HEAD responses must not carry a body
    if self.method != 'HEAD':
        buf.append(msg)
    if read_remaining_input:
        # Drain any unread request body so the connection stays usable
        self.input_reader.read()
    self.flushed_write(b''.join(buf))
|
||||
|
||||
def send_not_modified(self, etag=None):
    # Answer 304 Not Modified (no body), including the ETag when provided
    buf = [
        '%s %d %s' % (self.response_protocol, httplib.NOT_MODIFIED, httplib.responses[httplib.NOT_MODIFIED]),
        "Content-Length: 0",
        "Date: " + http_date(),
    ]
    if etag is not None:
        buf.append('ETag: ' + etag)
    self.send_buf(buf)
|
||||
|
||||
def send_buf(self, buf, include_cache_headers=True):
    # Append cache-related outheaders (if requested), then CRLF-join the
    # header lines and write them out in a single flushed write.
    # NOTE: mutates the passed-in buf list.
    if include_cache_headers:
        for header in ('Expires', 'Cache-Control', 'Vary'):
            val = self.outheaders.get(header)
            if val:
                buf.append(header + ': ' + val)
    buf.append('')
    buf = [(x + '\r\n').encode('ascii') for x in buf]
    self.flushed_write(b''.join(buf))
|
||||
|
||||
def send_range_not_satisfiable(self, content_length):
    # Answer 416 with a Content-Range header advertising the full length
    buf = [
        '%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]),
        "Date: " + http_date(),
        "Content-Range: bytes */%d" % content_length,
    ]
    self.send_buf(buf)
|
||||
|
||||
def flushed_write(self, data):
    """Write data to the connection's socket file and flush immediately."""
    sf = self.conn.socket_file
    sf.write(data)
    sf.flush()
|
||||
|
||||
def repr_for_log(self):
    """Multi-line, human readable summary of this request for log output."""
    lines = ['HTTPPair: %r' % self.request_line]
    if self.path:
        lines.append('Path: %r' % (self.path,))
    if self.qs:
        lines.append('Query: %r' % self.qs)
    if self.inheaders:
        lines.append('In Headers:')
        lines.append(self.inheaders.pretty('\t'))
    if self.outheaders:
        lines.append('Out Headers:')
        lines.append(self.outheaders.pretty('\t'))
    return '\n'.join(lines)
|
||||
|
||||
def generate_static_output(self, name, generator):
|
||||
return generate_static_output(self.server_loop.gso_cache, self.server_loop.gso_lock, name, generator)
|
||||
|
||||
def respond(self):
|
||||
if self.chunked_read:
|
||||
self.input_reader = ChunkedReader(self.conn.socket_file, self.max_request_body_size)
|
||||
else:
|
||||
self.input_reader = FixedSizeReader(self.conn.socket_file, self.request_content_length)
|
||||
|
||||
try:
|
||||
output = self.handle_request(self)
|
||||
except HTTP404 as e:
|
||||
self.simple_response(httplib.NOT_FOUND, e.message, read_remaining_input=True)
|
||||
return
|
||||
# Read and discard any remaining body from the HTTP request
|
||||
self.input_reader.read()
|
||||
if self.status_code is None:
|
||||
self.status_code = httplib.CREATED if self.method == 'POST' else httplib.OK
|
||||
|
||||
try:
|
||||
self.status_code, output = finalize_output(
|
||||
output, self.inheaders, self.outheaders, self.status_code,
|
||||
self.response_protocol is HTTP1, self.method, self.server_loop.opts)
|
||||
except IfNoneMatch as e:
|
||||
if self.method in ('GET', 'HEAD'):
|
||||
self.send_not_modified(e.etag)
|
||||
else:
|
||||
self.simple_response(httplib.PRECONDITION_FAILED)
|
||||
return
|
||||
except RangeNotSatisfiable as e:
|
||||
self.send_range_not_satisfiable(e.content_length)
|
||||
return
|
||||
|
||||
with self.conn.corked:
|
||||
self.send_headers()
|
||||
if self.method != 'HEAD':
|
||||
output.commit(self.conn.socket_file)
|
||||
self.conn.socket_file.flush()
|
||||
|
||||
def send_headers(self):
|
||||
self.sent_headers = True
|
||||
self.outheaders.set('Date', http_date(), replace_all=True)
|
||||
self.outheaders.set('Server', 'calibre %s' % __version__, replace_all=True)
|
||||
keep_alive = not self.close_connection and self.server_loop.opts.timeout > 0
|
||||
if keep_alive:
|
||||
self.outheaders.set('Keep-Alive', 'timeout=%d' % self.server_loop.opts.timeout)
|
||||
if 'Connection' not in self.outheaders:
|
||||
if self.response_protocol is HTTP11:
|
||||
if self.close_connection:
|
||||
self.outheaders.set('Connection', 'close')
|
||||
else:
|
||||
if not self.close_connection:
|
||||
self.outheaders.set('Connection', 'Keep-Alive')
|
||||
|
||||
ct = self.outheaders.get('Content-Type', '')
|
||||
if ct.startswith('text/') and 'charset=' not in ct:
|
||||
self.outheaders.set('Content-Type', ct + '; charset=UTF-8')
|
||||
|
||||
buf = [HTTP11 + (' %d ' % self.status_code) + httplib.responses[self.status_code]]
|
||||
for header, value in sorted(self.outheaders.iteritems(), key=itemgetter(0)):
|
||||
buf.append('%s: %s' % (header, value))
|
||||
buf.append('')
|
||||
self.conn.socket_file.write(b''.join((x + '\r\n').encode('ascii') for x in buf))
|
||||
|
||||
|
||||
def create_http_handler(handle_request):
|
||||
return partial(http_communicate, handle_request)
|
374
src/calibre/srv/http_request.py
Normal file
374
src/calibre/srv/http_request.py
Normal file
@ -0,0 +1,374 @@
|
||||
#!/usr/bin/env python2
|
||||
# vim:fileencoding=utf-8
|
||||
from __future__ import (unicode_literals, division, absolute_import,
|
||||
print_function)
|
||||
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import re, httplib, repr as reprlib
|
||||
from io import BytesIO, DEFAULT_BUFFER_SIZE
|
||||
from urllib import unquote
|
||||
|
||||
from calibre import as_unicode, force_unicode
|
||||
from calibre.ptempfile import SpooledTemporaryFile
|
||||
from calibre.srv.loop import Connection, READ, WRITE
|
||||
from calibre.srv.utils import MultiDict, HTTP1, HTTP11
|
||||
|
||||
protocol_map = {(1, 0):HTTP1, (1, 1):HTTP11}
|
||||
quoted_slash = re.compile(br'%2[fF]')
|
||||
HTTP_METHODS = {'HEAD', 'GET', 'PUT', 'POST', 'TRACE', 'DELETE', 'OPTIONS'}
|
||||
|
||||
def parse_request_uri(uri): # {{{
|
||||
"""Parse a Request-URI into (scheme, authority, path).
|
||||
|
||||
Note that Request-URI's must be one of::
|
||||
|
||||
Request-URI = "*" | absoluteURI | abs_path | authority
|
||||
|
||||
Therefore, a Request-URI which starts with a double forward-slash
|
||||
cannot be a "net_path"::
|
||||
|
||||
net_path = "//" authority [ abs_path ]
|
||||
|
||||
Instead, it must be interpreted as an "abs_path" with an empty first
|
||||
path segment::
|
||||
|
||||
abs_path = "/" path_segments
|
||||
path_segments = segment *( "/" segment )
|
||||
segment = *pchar *( ";" param )
|
||||
param = *pchar
|
||||
"""
|
||||
if uri == b'*':
|
||||
return None, None, uri
|
||||
|
||||
i = uri.find(b'://')
|
||||
if i > 0 and b'?' not in uri[:i]:
|
||||
# An absoluteURI.
|
||||
# If there's a scheme (and it must be http or https), then:
|
||||
# http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query
|
||||
# ]]
|
||||
scheme, remainder = uri[:i].lower(), uri[i + 3:]
|
||||
authority, path = remainder.split(b'/', 1)
|
||||
path = b'/' + path
|
||||
return scheme, authority, path
|
||||
|
||||
if uri.startswith(b'/'):
|
||||
# An abs_path.
|
||||
return None, None, uri
|
||||
else:
|
||||
# An authority.
|
||||
return None, uri, None
|
||||
# }}}
|
||||
|
||||
# HTTP Header parsing {{{
|
||||
|
||||
comma_separated_headers = {
|
||||
'Accept', 'Accept-Charset', 'Accept-Encoding',
|
||||
'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
|
||||
'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
|
||||
'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
|
||||
'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
|
||||
'WWW-Authenticate'
|
||||
}
|
||||
|
||||
decoded_headers = {
|
||||
'Transfer-Encoding', 'Connection', 'Keep-Alive', 'Expect',
|
||||
} | comma_separated_headers
|
||||
|
||||
class HTTPHeaderParser(object):
|
||||
|
||||
'''
|
||||
Parse HTTP headers. Use this class by repeatedly calling the created object
|
||||
with a single line at a time and checking the finished attribute. Can raise ValueError
|
||||
for malformed headers, in which case you should probably return BAD_REQUEST.
|
||||
|
||||
Headers which are repeated are folded together using a comma if their
|
||||
specification so dictates.
|
||||
'''
|
||||
__slots__ = ('hdict', 'lines', 'finished')
|
||||
|
||||
def __init__(self):
|
||||
self.hdict = MultiDict()
|
||||
self.lines = []
|
||||
self.finished = False
|
||||
|
||||
def push(self, *lines):
|
||||
for line in lines:
|
||||
self(line)
|
||||
|
||||
def __call__(self, line):
|
||||
'Process a single line'
|
||||
|
||||
def safe_decode(hname, value):
|
||||
try:
|
||||
return value.decode('ascii')
|
||||
except UnicodeDecodeError:
|
||||
if hname in decoded_headers:
|
||||
raise
|
||||
return value
|
||||
|
||||
def commit():
|
||||
if not self.lines:
|
||||
return
|
||||
line = b' '.join(self.lines)
|
||||
del self.lines[:]
|
||||
|
||||
k, v = line.partition(b':')[::2]
|
||||
key = k.strip().decode('ascii').title()
|
||||
val = safe_decode(key, v.strip())
|
||||
if not key or not val:
|
||||
raise ValueError('Malformed header line: %s' % reprlib.repr(line))
|
||||
if key in comma_separated_headers:
|
||||
existing = self.hdict.pop(key)
|
||||
if existing is not None:
|
||||
val = existing + ', ' + val
|
||||
self.hdict[key] = val
|
||||
|
||||
if line == b'\r\n':
|
||||
# Normal end of headers
|
||||
commit()
|
||||
self.finished = True
|
||||
return
|
||||
|
||||
if line[0] in b' \t':
|
||||
# It's a continuation line.
|
||||
if not self.lines:
|
||||
raise ValueError('Orphaned continuation line')
|
||||
self.lines.append(line.lstrip())
|
||||
else:
|
||||
commit()
|
||||
self.lines.append(line)
|
||||
|
||||
def read_headers(readline):
|
||||
p = HTTPHeaderParser()
|
||||
while not p.finished:
|
||||
p(readline())
|
||||
return p.hdict
|
||||
# }}}
|
||||
|
||||
class HTTPRequest(Connection):
|
||||
|
||||
request_handler = None
|
||||
static_cache = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
Connection.__init__(self, *args, **kwargs)
|
||||
self.corked = False
|
||||
self.max_header_line_size = int(1024 * self.opts.max_header_line_size)
|
||||
self.max_request_body_size = int(1024 * 1024 * self.opts.max_request_body_size)
|
||||
|
||||
def read(self, buf, endpos):
|
||||
size = endpos - buf.tell()
|
||||
if size > 0:
|
||||
data = self.recv(min(size, DEFAULT_BUFFER_SIZE))
|
||||
if data:
|
||||
buf.write(data)
|
||||
return len(data) >= size
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def readline(self, buf):
|
||||
if buf.tell() >= self.max_header_line_size - 1:
|
||||
self.simple_response(self.header_line_too_long_error_code)
|
||||
return
|
||||
data = self.recv(1)
|
||||
if data:
|
||||
buf.write(data)
|
||||
if b'\n' == data:
|
||||
line = buf.getvalue()
|
||||
buf.seek(0), buf.truncate()
|
||||
if line.endswith(b'\r\n'):
|
||||
return line
|
||||
else:
|
||||
self.simple_response(httplib.BAD_REQUEST, 'HTTP requires CRLF line terminators')
|
||||
|
||||
def connection_ready(self):
|
||||
'Become ready to read an HTTP request'
|
||||
self.method = self.request_line = None
|
||||
self.response_protocol = self.request_protocol = HTTP1
|
||||
self.path = self.query = None
|
||||
self.close_after_response = False
|
||||
self.header_line_too_long_error_code = httplib.REQUEST_URI_TOO_LONG
|
||||
self.response_started = False
|
||||
self.set_state(READ, self.parse_request_line, BytesIO(), first=True)
|
||||
|
||||
def parse_request_line(self, buf, event, first=False): # {{{
|
||||
line = self.readline(buf)
|
||||
if line is None:
|
||||
return
|
||||
if line == b'\r\n':
|
||||
# Ignore a single leading empty line, as per RFC 2616 sec 4.1
|
||||
if first:
|
||||
return self.set_state(READ, self.parse_request_line, BytesIO())
|
||||
return self.simple_response(httplib.BAD_REQUEST, 'Multiple leading empty lines not allowed')
|
||||
|
||||
try:
|
||||
method, uri, req_protocol = line.strip().split(b' ', 2)
|
||||
rp = int(req_protocol[5]), int(req_protocol[7])
|
||||
self.method = method.decode('ascii').upper()
|
||||
except Exception:
|
||||
return self.simple_response(httplib.BAD_REQUEST, "Malformed Request-Line")
|
||||
|
||||
if self.method not in HTTP_METHODS:
|
||||
return self.simple_response(httplib.BAD_REQUEST, "Unknown HTTP method")
|
||||
|
||||
try:
|
||||
self.request_protocol = protocol_map[rp]
|
||||
except KeyError:
|
||||
return self.simple_response(httplib.HTTP_VERSION_NOT_SUPPORTED)
|
||||
self.response_protocol = protocol_map[min((1, 1), rp)]
|
||||
scheme, authority, path = parse_request_uri(uri)
|
||||
if b'#' in path:
|
||||
return self.simple_response(httplib.BAD_REQUEST, "Illegal #fragment in Request-URI.")
|
||||
|
||||
if scheme:
|
||||
try:
|
||||
self.scheme = scheme.decode('ascii')
|
||||
except ValueError:
|
||||
return self.simple_response(httplib.BAD_REQUEST, 'Un-decodeable scheme')
|
||||
|
||||
qs = b''
|
||||
if b'?' in path:
|
||||
path, qs = path.split(b'?', 1)
|
||||
try:
|
||||
self.query = MultiDict.create_from_query_string(qs)
|
||||
except Exception:
|
||||
return self.simple_response(httplib.BAD_REQUEST, 'Unparseable query string')
|
||||
|
||||
try:
|
||||
path = '%2F'.join(unquote(x).decode('utf-8') for x in quoted_slash.split(path))
|
||||
except ValueError as e:
|
||||
return self.simple_response(httplib.BAD_REQUEST, as_unicode(e))
|
||||
self.path = tuple(filter(None, (x.replace('%2F', '/') for x in path.split('/'))))
|
||||
self.header_line_too_long_error_code = httplib.REQUEST_ENTITY_TOO_LARGE
|
||||
self.request_line = line.rstrip()
|
||||
self.set_state(READ, self.parse_header_line, HTTPHeaderParser(), BytesIO())
|
||||
# }}}
|
||||
|
||||
@property
|
||||
def state_description(self):
|
||||
return 'Request: %s' % force_unicode(self.request_line, 'utf-8')
|
||||
|
||||
def parse_header_line(self, parser, buf, event):
|
||||
line = self.readline(buf)
|
||||
if line is None:
|
||||
return
|
||||
try:
|
||||
parser(line)
|
||||
except ValueError:
|
||||
self.simple_response(httplib.BAD_REQUEST, 'Failed to parse header line')
|
||||
return
|
||||
if parser.finished:
|
||||
self.finalize_headers(parser.hdict)
|
||||
|
||||
def finalize_headers(self, inheaders):
|
||||
request_content_length = int(inheaders.get('Content-Length', 0))
|
||||
if request_content_length > self.max_request_body_size:
|
||||
return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE,
|
||||
"The entity sent with the request exceeds the maximum "
|
||||
"allowed bytes (%d)." % self.max_request_body_size)
|
||||
# Persistent connection support
|
||||
if self.response_protocol is HTTP11:
|
||||
# Both server and client are HTTP/1.1
|
||||
if inheaders.get("Connection", "") == "close":
|
||||
self.close_after_response = True
|
||||
else:
|
||||
# Either the server or client (or both) are HTTP/1.0
|
||||
if inheaders.get("Connection", "") != "Keep-Alive":
|
||||
self.close_after_response = True
|
||||
|
||||
# Transfer-Encoding support
|
||||
te = ()
|
||||
if self.response_protocol is HTTP11:
|
||||
rte = inheaders.get("Transfer-Encoding")
|
||||
if rte:
|
||||
te = [x.strip().lower() for x in rte.split(",") if x.strip()]
|
||||
chunked_read = False
|
||||
if te:
|
||||
for enc in te:
|
||||
if enc == "chunked":
|
||||
chunked_read = True
|
||||
else:
|
||||
# Note that, even if we see "chunked", we must reject
|
||||
# if there is an extension we don't recognize.
|
||||
return self.simple_response(httplib.NOT_IMPLEMENTED, "Unknown transfer encoding: %r" % enc)
|
||||
|
||||
if inheaders.get("Expect", '').lower() == "100-continue":
|
||||
buf = BytesIO((HTTP11 + " 100 Continue\r\n\r\n").encode('ascii'))
|
||||
return self.set_state(WRITE, self.write_continue, buf, inheaders, request_content_length, chunked_read)
|
||||
|
||||
self.read_request_body(inheaders, request_content_length, chunked_read)
|
||||
|
||||
def write_continue(self, buf, inheaders, request_content_length, chunked_read, event):
|
||||
if self.write(buf):
|
||||
self.read_request_body(inheaders, request_content_length, chunked_read)
|
||||
|
||||
def read_request_body(self, inheaders, request_content_length, chunked_read):
|
||||
buf = SpooledTemporaryFile(prefix='rq-body-', max_size=DEFAULT_BUFFER_SIZE, dir=self.tdir)
|
||||
if chunked_read:
|
||||
self.set_state(READ, self.read_chunk_length, inheaders, BytesIO(), buf, [0])
|
||||
else:
|
||||
if request_content_length > 0:
|
||||
self.set_state(READ, self.sized_read, inheaders, buf, request_content_length)
|
||||
else:
|
||||
self.prepare_response(inheaders, BytesIO())
|
||||
|
||||
def sized_read(self, inheaders, buf, request_content_length, event):
|
||||
if self.read(buf, request_content_length):
|
||||
self.prepare_response(inheaders, buf)
|
||||
|
||||
def read_chunk_length(self, inheaders, line_buf, buf, bytes_read, event):
|
||||
line = self.readline(line_buf)
|
||||
if line is None:
|
||||
return
|
||||
bytes_read[0] += len(line)
|
||||
try:
|
||||
chunk_size = int(line.strip(), 16)
|
||||
except Exception:
|
||||
return self.simple_response(httplib.BAD_REQUEST, '%s is not a valid chunk size' % reprlib.repr(line.strip()))
|
||||
if bytes_read[0] + chunk_size + 2 > self.max_request_body_size:
|
||||
return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE,
|
||||
'Chunked request is larger than %d bytes' % self.max_request_body_size)
|
||||
if chunk_size == 0:
|
||||
self.set_state(READ, self.read_chunk_separator, inheaders, BytesIO(), buf, bytes_read, last=True)
|
||||
else:
|
||||
self.set_state(READ, self.read_chunk, inheaders, buf, chunk_size, buf.tell() + chunk_size, bytes_read)
|
||||
|
||||
def read_chunk(self, inheaders, buf, chunk_size, end, bytes_read, event):
|
||||
if not self.read(buf, end):
|
||||
return
|
||||
bytes_read[0] += chunk_size
|
||||
self.set_state(READ, self.read_chunk_separator, inheaders, BytesIO(), buf, bytes_read)
|
||||
|
||||
def read_chunk_separator(self, inheaders, line_buf, buf, bytes_read, event, last=False):
|
||||
line = self.readline(line_buf)
|
||||
if line is None:
|
||||
return
|
||||
if line != b'\r\n':
|
||||
return self.simple_response(httplib.BAD_REQUEST, 'Chunk does not have trailing CRLF')
|
||||
bytes_read[0] += len(line)
|
||||
if bytes_read[0] > self.max_request_body_size:
|
||||
return self.simple_response(httplib.REQUEST_ENTITY_TOO_LARGE,
|
||||
'Chunked request is larger than %d bytes' % self.max_request_body_size)
|
||||
if last:
|
||||
self.prepare_response(inheaders, buf)
|
||||
else:
|
||||
self.set_state(READ, self.read_chunk_length, inheaders, BytesIO(), buf, bytes_read)
|
||||
|
||||
def handle_timeout(self):
|
||||
if self.response_started:
|
||||
return False
|
||||
self.simple_response(httplib.REQUEST_TIMEOUT)
|
||||
return True
|
||||
|
||||
def write(self, buf, end=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def simple_response(self, status_code, msg='', close_after_response=True):
|
||||
raise NotImplementedError()
|
||||
|
||||
def prepare_response(self, inheaders, request_body_file):
|
||||
raise NotImplementedError()
|
525
src/calibre/srv/http_response.py
Normal file
525
src/calibre/srv/http_response.py
Normal file
@ -0,0 +1,525 @@
|
||||
#!/usr/bin/env python2
|
||||
# vim:fileencoding=utf-8
|
||||
from __future__ import (unicode_literals, division, absolute_import,
|
||||
print_function)
|
||||
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import os, httplib, hashlib, uuid, zlib, time, struct, repr as reprlib
|
||||
from select import PIPE_BUF
|
||||
from collections import namedtuple
|
||||
from io import BytesIO, DEFAULT_BUFFER_SIZE
|
||||
from itertools import chain, repeat, izip_longest
|
||||
from operator import itemgetter
|
||||
from functools import wraps
|
||||
|
||||
from calibre import guess_type, force_unicode
|
||||
from calibre.constants import __version__
|
||||
from calibre.srv.loop import WRITE
|
||||
from calibre.srv.errors import HTTP404
|
||||
from calibre.srv.http_request import HTTPRequest, read_headers
|
||||
from calibre.srv.sendfile import file_metadata, sendfile_to_socket_async, CannotSendfile, SendfileInterrupted
|
||||
from calibre.srv.utils import MultiDict, start_cork, stop_cork, http_date, HTTP1, HTTP11, socket_errors_socket_closed
|
||||
|
||||
Range = namedtuple('Range', 'start stop size')
|
||||
MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii')
|
||||
|
||||
def header_list_to_file(buf): # {{{
|
||||
buf.append('')
|
||||
return BytesIO(b''.join((x + '\r\n').encode('ascii') for x in buf))
|
||||
# }}}
|
||||
|
||||
def parse_multipart_byterange(buf, content_type): # {{{
|
||||
sep = (content_type.rsplit('=', 1)[-1]).encode('utf-8')
|
||||
ans = []
|
||||
|
||||
def parse_part():
|
||||
line = buf.readline()
|
||||
if not line:
|
||||
raise ValueError('Premature end of message')
|
||||
if not line.startswith(b'--' + sep):
|
||||
raise ValueError('Malformed start of multipart message: %s' % reprlib.repr(line))
|
||||
if line.endswith(b'--'):
|
||||
return None
|
||||
headers = read_headers(buf.readline)
|
||||
cr = headers.get('Content-Range')
|
||||
if not cr:
|
||||
raise ValueError('Missing Content-Range header in sub-part')
|
||||
if not cr.startswith('bytes '):
|
||||
raise ValueError('Malformed Content-Range header in sub-part, no prefix')
|
||||
try:
|
||||
start, stop = map(lambda x: int(x.strip()), cr.partition(' ')[-1].partition('/')[0].partition('-')[::2])
|
||||
except Exception:
|
||||
raise ValueError('Malformed Content-Range header in sub-part, failed to parse byte range')
|
||||
content_length = stop - start + 1
|
||||
ret = buf.read(content_length)
|
||||
if len(ret) != content_length:
|
||||
raise ValueError('Malformed sub-part, length of body not equal to length specified in Content-Range')
|
||||
buf.readline()
|
||||
return (start, ret)
|
||||
while True:
|
||||
data = parse_part()
|
||||
if data is None:
|
||||
break
|
||||
ans.append(data)
|
||||
return ans
|
||||
# }}}
|
||||
|
||||
def parse_if_none_match(val): # {{{
|
||||
return {x.strip() for x in val.split(',')}
|
||||
# }}}
|
||||
|
||||
def acceptable_encoding(val, allowed=frozenset({'gzip'})): # {{{
|
||||
def enc(x):
|
||||
e, r = x.partition(';')[::2]
|
||||
p, v = r.partition('=')[::2]
|
||||
q = 1.0
|
||||
if p == 'q' and v:
|
||||
try:
|
||||
q = float(v)
|
||||
except Exception:
|
||||
pass
|
||||
return e.lower(), q
|
||||
|
||||
emap = dict(enc(x.strip()) for x in val.split(','))
|
||||
acceptable = sorted(set(emap) & allowed, key=emap.__getitem__, reverse=True)
|
||||
if acceptable:
|
||||
return acceptable[0]
|
||||
# }}}
|
||||
|
||||
def get_ranges(headervalue, content_length): # {{{
|
||||
''' Return a list of ranges from the Range header. If this function returns
|
||||
an empty list, it indicates no valid range was found. '''
|
||||
if not headervalue:
|
||||
return None
|
||||
|
||||
result = []
|
||||
try:
|
||||
bytesunit, byteranges = headervalue.split("=", 1)
|
||||
except Exception:
|
||||
return None
|
||||
if bytesunit.strip() != 'bytes':
|
||||
return None
|
||||
|
||||
for brange in byteranges.split(","):
|
||||
start, stop = [x.strip() for x in brange.split("-", 1)]
|
||||
if start:
|
||||
if not stop:
|
||||
stop = content_length - 1
|
||||
try:
|
||||
start, stop = int(start), int(stop)
|
||||
except Exception:
|
||||
continue
|
||||
if start >= content_length:
|
||||
continue
|
||||
if stop < start:
|
||||
continue
|
||||
stop = min(stop, content_length - 1)
|
||||
result.append(Range(start, stop, stop - start + 1))
|
||||
elif stop:
|
||||
# Negative subscript (last N bytes)
|
||||
try:
|
||||
stop = int(stop)
|
||||
except Exception:
|
||||
continue
|
||||
if stop > content_length:
|
||||
result.append(Range(0, content_length-1, content_length))
|
||||
else:
|
||||
result.append(Range(content_length - stop, content_length - 1, stop))
|
||||
|
||||
return result
|
||||
# }}}
|
||||
|
||||
# gzip transfer encoding {{{
|
||||
def gzip_prefix(mtime=None):
|
||||
# See http://www.gzip.org/zlib/rfc-gzip.html
|
||||
if mtime is None:
|
||||
mtime = time.time()
|
||||
return b''.join((
|
||||
b'\x1f\x8b', # ID1 and ID2: gzip marker
|
||||
b'\x08', # CM: compression method
|
||||
b'\x00', # FLG: none set
|
||||
# MTIME: 4 bytes
|
||||
struct.pack(b"<L", int(mtime) & 0xFFFFFFFF),
|
||||
b'\x02', # XFL: max compression, slowest algo
|
||||
b'\xff', # OS: unknown
|
||||
))
|
||||
|
||||
def compress_readable_output(src_file, compress_level=6):
|
||||
crc = zlib.crc32(b"")
|
||||
size = 0
|
||||
zobj = zlib.compressobj(compress_level,
|
||||
zlib.DEFLATED, -zlib.MAX_WBITS,
|
||||
zlib.DEF_MEM_LEVEL, 0)
|
||||
prefix_written = False
|
||||
while True:
|
||||
data = src_file.read(DEFAULT_BUFFER_SIZE)
|
||||
if not data:
|
||||
break
|
||||
size += len(data)
|
||||
crc = zlib.crc32(data, crc)
|
||||
data = zobj.compress(data)
|
||||
if not prefix_written:
|
||||
prefix_written = True
|
||||
data = gzip_prefix(time.time()) + data
|
||||
yield data
|
||||
yield zobj.flush() + struct.pack(b"<L", crc & 0xFFFFFFFF) + struct.pack(b"<L", size & 0xFFFFFFFF)
|
||||
# }}}
|
||||
|
||||
def get_range_parts(ranges, content_type, content_length): # {{{
|
||||
|
||||
def part(r):
|
||||
ans = ['--%s' % MULTIPART_SEPARATOR, 'Content-Range: bytes %d-%d/%d' % (r.start, r.stop, content_length)]
|
||||
if content_type:
|
||||
ans.append('Content-Type: %s' % content_type)
|
||||
ans.append('')
|
||||
return ('\r\n'.join(ans)).encode('ascii')
|
||||
return list(map(part, ranges)) + [('--%s--' % MULTIPART_SEPARATOR).encode('ascii')]
|
||||
# }}}
|
||||
|
||||
class RequestData(object): # {{{
|
||||
|
||||
def __init__(self, method, path, query, inheaders, request_body_file, outheaders, response_protocol, static_cache, opts):
|
||||
self.method, self.path, self.query, self.inheaders, self.request_body_file, self.outheaders, self.response_protocol, self.static_cache = (
|
||||
method, path, query, inheaders, request_body_file, outheaders, response_protocol, static_cache
|
||||
)
|
||||
self.opts = opts
|
||||
self.status_code = httplib.CREATED if self.method == 'POST' else httplib.OK
|
||||
|
||||
def generate_static_output(self, name, generator):
|
||||
ans = self.static_cache.get(name)
|
||||
if ans is None:
|
||||
ans = self.static_cache[name] = StaticOutput(generator())
|
||||
return ans
|
||||
|
||||
def read(self, size=-1):
|
||||
return self.request_body_file.read(size)
|
||||
# }}}
|
||||
|
||||
class ReadableOutput(object):
|
||||
|
||||
def __init__(self, output, etag=None, content_length=None):
|
||||
self.src_file = output
|
||||
if content_length is None:
|
||||
self.src_file.seek(0, os.SEEK_END)
|
||||
self.content_length = self.src_file.tell()
|
||||
else:
|
||||
self.content_length = content_length
|
||||
self.etag = etag
|
||||
self.accept_ranges = True
|
||||
self.use_sendfile = False
|
||||
self.src_file.seek(0)
|
||||
|
||||
def filesystem_file_output(output, outheaders, stat_result):
|
||||
etag = '"%s"' % hashlib.sha1(type('')(stat_result.st_mtime) + force_unicode(output.name or '')).hexdigest()
|
||||
self = ReadableOutput(output, etag=etag, content_length=stat_result.st_size)
|
||||
self.name = output.name
|
||||
self.use_sendfile = True
|
||||
return self
|
||||
|
||||
def dynamic_output(output, outheaders):
|
||||
if isinstance(output, bytes):
|
||||
data = output
|
||||
else:
|
||||
data = output.encode('utf-8')
|
||||
ct = outheaders.get('Content-Type')
|
||||
if not ct:
|
||||
outheaders.set('Content-Type', 'text/plain; charset=UTF-8', replace_all=True)
|
||||
ans = ReadableOutput(BytesIO(data))
|
||||
ans.accept_ranges = False
|
||||
return ans
|
||||
|
||||
class GeneratedOutput(object):
|
||||
|
||||
def __init__(self, output, etag=None):
|
||||
self.output = output
|
||||
self.content_length = None
|
||||
self.etag = etag
|
||||
self.accept_ranges = False
|
||||
|
||||
class StaticOutput(object):
|
||||
|
||||
def __init__(self, data):
|
||||
if isinstance(data, type('')):
|
||||
data = data.encode('utf-8')
|
||||
self.data = data
|
||||
self.etag = '"%s"' % hashlib.sha1(data).hexdigest()
|
||||
self.content_length = len(data)
|
||||
|
||||
class HTTPConnection(HTTPRequest):
|
||||
|
||||
def write(self, buf, end=None):
|
||||
pos = buf.tell()
|
||||
if end is None:
|
||||
buf.seek(0, os.SEEK_END)
|
||||
end = buf.tell()
|
||||
buf.seek(pos)
|
||||
limit = end - pos
|
||||
if limit == 0:
|
||||
return True
|
||||
if self.use_sendfile and not isinstance(buf, BytesIO):
|
||||
try:
|
||||
sent = sendfile_to_socket_async(buf, pos, limit, self.socket)
|
||||
except CannotSendfile:
|
||||
self.use_sendfile = False
|
||||
return False
|
||||
except SendfileInterrupted:
|
||||
return False
|
||||
except IOError as e:
|
||||
if e.errno in socket_errors_socket_closed:
|
||||
self.ready = self.use_sendfile = False
|
||||
return False
|
||||
raise
|
||||
if sent == 0:
|
||||
# Something bad happened, was the file modified on disk by
|
||||
# another process?
|
||||
self.use_sendfile = self.ready = False
|
||||
raise IOError('sendfile() failed to write any bytes to the socket')
|
||||
else:
|
||||
sent = self.send(buf.read(min(limit, PIPE_BUF)))
|
||||
buf.seek(pos + sent)
|
||||
return buf.tell() == end
|
||||
|
||||
def simple_response(self, status_code, msg='', close_after_response=True):
|
||||
if self.response_protocol is HTTP1 and status_code in (httplib.REQUEST_ENTITY_TOO_LARGE, httplib.REQUEST_URI_TOO_LONG):
|
||||
# HTTP/1.0 has no 413/414 codes
|
||||
status_code = httplib.BAD_REQUEST
|
||||
self.close_after_response = close_after_response
|
||||
msg = msg.encode('utf-8')
|
||||
ct = 'http' if self.method == 'TRACE' else 'plain'
|
||||
buf = [
|
||||
'%s %d %s' % (self.response_protocol, status_code, httplib.responses[status_code]),
|
||||
"Content-Length: %s" % len(msg),
|
||||
"Content-Type: text/%s; charset=UTF-8" % ct,
|
||||
"Date: " + http_date(),
|
||||
]
|
||||
if self.close_after_response and self.response_protocol is HTTP11:
|
||||
buf.append("Connection: close")
|
||||
buf.append('')
|
||||
buf = [(x + '\r\n').encode('ascii') for x in buf]
|
||||
if self.method != 'HEAD':
|
||||
buf.append(msg)
|
||||
self.response_ready(BytesIO(b''.join(buf)))
|
||||
|
||||
def prepare_response(self, inheaders, request_body_file):
|
||||
if self.method == 'TRACE':
|
||||
msg = force_unicode(self.request_line, 'utf-8') + '\n' + inheaders.pretty()
|
||||
return self.simple_response(httplib.OK, msg, close_after_response=False)
|
||||
request_body_file.seek(0)
|
||||
outheaders = MultiDict()
|
||||
data = RequestData(self.method, self.path, self.query, inheaders, request_body_file, outheaders, self.response_protocol, self.static_cache, self.opts)
|
||||
try:
|
||||
output = self.request_handler(data)
|
||||
except HTTP404 as e:
|
||||
return self.simple_response(httplib.NOT_FOUND, msg=e.message or '', close_after_response=False)
|
||||
|
||||
output = self.finalize_output(output, data, self.method is HTTP1)
|
||||
if output is None:
|
||||
return
|
||||
|
||||
outheaders.set('Date', http_date(), replace_all=True)
|
||||
outheaders.set('Server', 'calibre %s' % __version__, replace_all=True)
|
||||
keep_alive = not self.close_after_response and self.opts.timeout > 0
|
||||
if keep_alive:
|
||||
outheaders.set('Keep-Alive', 'timeout=%d' % self.opts.timeout)
|
||||
if 'Connection' not in outheaders:
|
||||
if self.response_protocol is HTTP11:
|
||||
if self.close_after_response:
|
||||
outheaders.set('Connection', 'close')
|
||||
else:
|
||||
if not self.close_after_response:
|
||||
outheaders.set('Connection', 'Keep-Alive')
|
||||
|
||||
ct = outheaders.get('Content-Type', '')
|
||||
if ct.startswith('text/') and 'charset=' not in ct:
|
||||
outheaders.set('Content-Type', ct + '; charset=UTF-8')
|
||||
|
||||
buf = [HTTP11 + (' %d ' % data.status_code) + httplib.responses[data.status_code]]
|
||||
for header, value in sorted(outheaders.iteritems(), key=itemgetter(0)):
|
||||
buf.append('%s: %s' % (header, value))
|
||||
buf.append('')
|
||||
self.response_ready(BytesIO(b''.join((x + '\r\n').encode('ascii') for x in buf)), output=output)
|
||||
|
||||
def send_range_not_satisfiable(self, content_length):
|
||||
buf = [
|
||||
'%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]),
|
||||
"Date: " + http_date(),
|
||||
"Content-Range: bytes */%d" % content_length,
|
||||
]
|
||||
self.response_ready(header_list_to_file(buf))
|
||||
|
||||
def send_not_modified(self, etag=None):
|
||||
buf = [
|
||||
'%s %d %s' % (self.response_protocol, httplib.NOT_MODIFIED, httplib.responses[httplib.NOT_MODIFIED]),
|
||||
"Content-Length: 0",
|
||||
"Date: " + http_date(),
|
||||
]
|
||||
if etag is not None:
|
||||
buf.append('ETag: ' + etag)
|
||||
self.response_ready(header_list_to_file(buf))
|
||||
|
||||
def response_ready(self, header_file, output=None):
|
||||
self.response_started = True
|
||||
start_cork(self.socket)
|
||||
self.corked = True
|
||||
self.use_sendfile = False
|
||||
self.set_state(WRITE, self.write_response_headers, header_file, output)
|
||||
|
||||
def write_response_headers(self, buf, output, event):
|
||||
if self.write(buf):
|
||||
self.write_response_body(output)
|
||||
|
||||
def write_response_body(self, output):
|
||||
if output is None or self.method == 'HEAD':
|
||||
self.reset_state()
|
||||
return
|
||||
if isinstance(output, ReadableOutput):
|
||||
self.use_sendfile = output.use_sendfile and self.opts.use_sendfile and sendfile_to_socket_async is not None
|
||||
if output.ranges is not None:
|
||||
if isinstance(output.ranges, Range):
|
||||
r = output.ranges
|
||||
output.src_file.seek(r.start)
|
||||
self.set_state(WRITE, self.write_buf, output.src_file, end=r.stop + 1)
|
||||
else:
|
||||
self.set_state(WRITE, self.write_ranges, output.src_file, output.ranges, first=True)
|
||||
else:
|
||||
self.set_state(WRITE, self.write_buf, output.src_file)
|
||||
elif isinstance(output, GeneratedOutput):
|
||||
self.set_state(WRITE, self.write_iter, chain(output.output, repeat(None, 1)))
|
||||
else:
|
||||
raise TypeError('Unknown output type: %r' % output)
|
||||
|
||||
def write_buf(self, buf, event, end=None):
    # Keep writing until buf (up to end) is exhausted, then the response is done.
    finished = self.write(buf, end=end)
    if finished:
        self.reset_state()
||||
def write_ranges(self, buf, ranges, event, first=False):
    """Write the next part of a multipart/byteranges response."""
    rng, range_part = next(ranges)
    if rng is None:
        # EOF: only the trailing separator remains
        self.set_state(WRITE, self.write_buf, BytesIO(b'\r\n' + range_part))
        return
    buf.seek(rng.start)
    # The first part is not preceded by a CRLF separator
    prefix = b'' if first else b'\r\n'
    self.set_state(WRITE, self.write_range_part,
                   BytesIO(prefix + range_part + b'\r\n'), buf, rng.stop + 1, ranges)
||||
def write_range_part(self, part_buf, buf, end, ranges, event):
    # Send the per-range multipart headers first, then the range's data.
    if self.write(part_buf):
        self.set_state(WRITE, self.write_range, buf, end, ranges)
||||
def write_range(self, buf, end, ranges, event):
    # Once this range's bytes are fully sent, advance to the next range.
    if self.write(buf, end=end):
        self.set_state(WRITE, self.write_ranges, buf, ranges)
||||
def write_iter(self, output, event):
    """Frame and queue the next chunk of a chunked-encoded generated response."""
    chunk = next(output)
    if chunk is None:
        # Generator exhausted: send the terminating zero-length chunk
        self.set_state(WRITE, self.write_chunk, BytesIO(b'0\r\n\r\n'), output, last=True)
        return
    if not isinstance(chunk, bytes):
        chunk = chunk.encode('utf-8')
    # Chunked framing: hex length, CRLF, payload, CRLF
    framed = ('%X\r\n' % len(chunk)).encode('ascii') + chunk + b'\r\n'
    self.set_state(WRITE, self.write_chunk, BytesIO(framed), output)
||||
def write_chunk(self, buf, output, event, last=False):
    # When the current chunk is flushed, either finish or fetch the next one.
    if not self.write(buf):
        return
    if last:
        self.reset_state()
    else:
        self.set_state(WRITE, self.write_iter, output)
||||
def reset_state(self):
    """Response complete: uncork the socket and rearm for the next request."""
    self.connection_ready()
    # Keep-alive unless the response demanded the connection be closed
    self.ready = not self.close_after_response
    stop_cork(self.socket)
    self.corked = False
||||
def report_unhandled_exception(self, e, formatted_traceback):
    # Best effort: let the client know the server failed internally.
    self.simple_response(httplib.INTERNAL_SERVER_ERROR)
||||
def finalize_output(self, output, request, is_http1):
    """Normalize *output* into an Output object and set the response headers
    (Content-Type/Length/Range/Encoding, ETag, chunking) accordingly.

    Returns the output object to write, or None when a 304/412/416 response
    has already been queued instead.
    """
    opts = self.opts
    outheaders = request.outheaders
    stat_result = file_metadata(output)
    if stat_result is not None:
        # Path on the filesystem: serve the file, guessing its MIME type
        output = filesystem_file_output(output, outheaders, stat_result)
        if 'Content-Type' not in outheaders:
            mt = guess_type(output.name)[0]
            if mt:
                if mt in {'text/plain', 'text/html', 'application/javascript', 'text/css'}:
                    # BUG FIX: was `mt =+ '...'` (unary plus on a str, a
                    # TypeError at runtime); append the charset instead.
                    mt += '; charset=UTF-8'
                outheaders['Content-Type'] = mt
    elif isinstance(output, (bytes, type(''))):
        output = dynamic_output(output, outheaders)
    elif hasattr(output, 'read'):
        output = ReadableOutput(output)
    elif isinstance(output, StaticOutput):
        output = ReadableOutput(BytesIO(output.data), etag=output.etag, content_length=output.content_length)
    else:
        output = GeneratedOutput(output)
    ct = outheaders.get('Content-Type', '').partition(';')[0]
    # Only compress textual content, for non-HTTP/1.0 clients that accept it,
    # and only when the payload is large enough to be worth it
    compressible = (not ct or ct.startswith('text/') or ct.startswith('image/svg') or
                    ct in {'application/json', 'application/javascript'})
    compressible = (compressible and request.status_code == httplib.OK and
                    (opts.compress_min_size > -1 and output.content_length >= opts.compress_min_size) and
                    acceptable_encoding(request.inheaders.get('Accept-Encoding', '')) and not is_http1)
    accept_ranges = (not compressible and output.accept_ranges is not None and request.status_code == httplib.OK and
                     not is_http1)
    ranges = get_ranges(request.inheaders.get('Range'), output.content_length) if output.accept_ranges and self.method in ('GET', 'HEAD') else None
    if_range = (request.inheaders.get('If-Range') or '').strip()
    if if_range and if_range != output.etag:
        # Entity changed since the client cached it: send the whole body
        ranges = None
    if ranges is not None and not ranges:
        # A Range header was present but no range was satisfiable
        return self.send_range_not_satisfiable(output.content_length)

    # Clear any stale values for the headers we are about to control.
    # BUG FIX: was outheaders.pop('header', ...) which popped the literal
    # string 'header' instead of each header name in the loop.
    for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'):
        outheaders.pop(header, all=True)

    none_match = parse_if_none_match(request.inheaders.get('If-None-Match', ''))
    matched = '*' in none_match or (output.etag and output.etag in none_match)
    if matched:
        if self.method in ('GET', 'HEAD'):
            self.send_not_modified(output.etag)
        else:
            self.simple_response(httplib.PRECONDITION_FAILED)
        return

    output.ranges = None

    if output.etag and self.method in ('GET', 'HEAD'):
        outheaders.set('ETag', output.etag, replace_all=True)
    if accept_ranges:
        outheaders.set('Accept-Ranges', 'bytes', replace_all=True)
    if compressible and not ranges:
        outheaders.set('Content-Encoding', 'gzip', replace_all=True)
        output = GeneratedOutput(compress_readable_output(output.src_file), etag=output.etag)
    if output.content_length is not None and not compressible and not ranges:
        outheaders.set('Content-Length', '%d' % output.content_length, replace_all=True)

    if compressible or output.content_length is None:
        # Size unknown in advance (gzip stream or generator): chunk it
        outheaders.set('Transfer-Encoding', 'chunked', replace_all=True)

    if ranges:
        if len(ranges) == 1:
            r = ranges[0]
            outheaders.set('Content-Length', '%d' % r.size, replace_all=True)
            outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True)
            output.ranges = r
        else:
            range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length)
            # +4 per range for the CRLF pairs framing each part's data
            size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges)
            outheaders.set('Content-Length', '%d' % size, replace_all=True)
            outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True)
            output.ranges = izip_longest(ranges, range_parts)
        request.status_code = httplib.PARTIAL_CONTENT
    return output
|
||||
def create_http_handler(handler):
    """Return a connection factory that creates HTTPConnection objects bound
    to *handler*.

    All connections produced by one factory share a single static_cache
    dict, so cached static data survives across connections.
    """
    static_cache = {}

    @wraps(handler)
    def wrapper(*args, **kwargs):
        ans = HTTPConnection(*args, **kwargs)
        ans.request_handler = handler
        # BUG FIX: the original assigned a fresh `{}` per connection, which
        # left the shared static_cache above unused and made the cache
        # useless; share the one dict instead.
        ans.static_cache = static_cache
        return ans
    return wrapper
@ -6,510 +6,132 @@ from __future__ import (unicode_literals, division, absolute_import,
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import socket, os, ssl, time, sys
|
||||
from operator import and_
|
||||
from Queue import Queue, Full
|
||||
from threading import Thread, current_thread, Lock
|
||||
from io import DEFAULT_BUFFER_SIZE, BytesIO
|
||||
import ssl, socket, select, os, traceback
|
||||
from io import BytesIO
|
||||
from functools import partial
|
||||
|
||||
from calibre.srv.errors import MaxSizeExceeded
|
||||
from calibre import as_unicode
|
||||
from calibre.ptempfile import TemporaryDirectory
|
||||
from calibre.srv.opts import Options
|
||||
from calibre.srv.utils import socket_errors_to_ignore, socket_error_eintr, socket_errors_nonblocking, Corked, HandleInterrupt
|
||||
from calibre.srv.utils import (
|
||||
socket_errors_socket_closed, socket_errors_nonblocking, HandleInterrupt, socket_errors_eintr)
|
||||
from calibre.utils.socket_inheritance import set_socket_inherit
|
||||
from calibre.utils.logging import ThreadSafeLog
|
||||
from calibre.utils.monotonic import monotonic
|
||||
|
||||
class SocketFile(object): # {{{
|
||||
"""Faux file object attached to a socket object. Works with non-blocking
|
||||
sockets, unlike the fileobject created by socket.makefile() """
|
||||
READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR'
|
||||
|
||||
name = "<socket>"
|
||||
class Connection(object):
|
||||
|
||||
__slots__ = (
|
||||
"mode", "bufsize", "softspace", "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf", "_wbuf_len", "_close", 'bytes_read', 'bytes_written',
|
||||
)
|
||||
|
||||
def __init__(self, sock, bufsize=-1, close=False):
|
||||
self._sock = sock
|
||||
self.bytes_read = self.bytes_written = 0
|
||||
self.mode = 'r+b'
|
||||
self.bufsize = DEFAULT_BUFFER_SIZE if bufsize < 0 else bufsize
|
||||
self.softspace = False
|
||||
# _rbufsize is the suggested recv buffer size. It is *strictly*
|
||||
# obeyed within readline() for recv calls. If it is larger than
|
||||
# default_bufsize it will be used for recv calls within read().
|
||||
if self.bufsize == 0:
|
||||
self._rbufsize = 1
|
||||
elif bufsize == 1:
|
||||
self._rbufsize = DEFAULT_BUFFER_SIZE
|
||||
def __init__(self, socket, opts, ssl_context, tdir):
|
||||
self.opts = opts
|
||||
self.tdir = tdir
|
||||
self.ssl_context = ssl_context
|
||||
self.wait_for = READ
|
||||
self.response_started = False
|
||||
if self.ssl_context is not None:
|
||||
self.ready = False
|
||||
self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False)
|
||||
self.set_state(RDWR, self.do_ssl_handshake)
|
||||
else:
|
||||
self._rbufsize = bufsize
|
||||
self._wbufsize = bufsize
|
||||
# We use BytesIO for the read buffer to avoid holding a list
|
||||
# of variously sized string objects which have been known to
|
||||
# fragment the heap due to how they are malloc()ed and often
|
||||
# realloc()ed down much smaller than their original allocation.
|
||||
self._rbuf = BytesIO()
|
||||
self._wbuf = [] # A list of strings
|
||||
self._wbuf_len = 0
|
||||
self._close = close
|
||||
self.ready = True
|
||||
self.socket = socket
|
||||
self.connection_ready()
|
||||
self.last_activity = monotonic()
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
return self._sock is None
|
||||
def set_state(self, wait_for, func, *args, **kwargs):
|
||||
self.wait_for = wait_for
|
||||
if args or kwargs:
|
||||
func = partial(func, *args, **kwargs)
|
||||
self.handle_event = func
|
||||
|
||||
def do_ssl_handshake(self, event):
    """Drive the non-blocking TLS handshake, re-arming the event loop until
    it completes, then mark the connection ready."""
    try:
        self._sslobj.do_handshake()
    except ssl.SSLWantReadError:
        self.set_state(READ, self.do_ssl_handshake)
    except ssl.SSLWantWriteError:
        self.set_state(WRITE, self.do_ssl_handshake)
    else:
        # BUG FIX: only mark the connection ready once the handshake has
        # actually completed. The original fell through and ran this even
        # when SSLWantRead/WriteError indicated the handshake was still in
        # progress, clobbering the state set above.
        self.ready = True
        self.connection_ready()
|
||||
def send(self, data):
    """Non-blocking send. Returns the number of bytes written; 0 when the
    socket would block, and 0 (with self.ready cleared) when the peer has
    closed the connection."""
    try:
        sent = self.socket.send(data)
    except socket.error as e:
        if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
            return 0
        if e.errno in socket_errors_socket_closed:
            self.ready = False
            return 0
        raise
    self.last_activity = monotonic()
    return sent
|
||||
def recv(self, buffer_size):
    """Non-blocking recv. Returns b'' when the socket would block; also
    returns b'' and clears self.ready when the peer has closed."""
    try:
        data = self.socket.recv(buffer_size)
    except socket.error as e:
        if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr:
            return b''
        if e.errno in socket_errors_socket_closed:
            self.ready = False
            return b''
        raise
    self.last_activity = monotonic()
    if not data:
        # A closed connection shows up as readable with recv() returning 0 bytes
        self.ready = False
        return b''
    return data
||||
|
||||
def close(self):
|
||||
try:
|
||||
if self._sock is not None:
|
||||
try:
|
||||
self.flush()
|
||||
except socket.error:
|
||||
pass
|
||||
finally:
|
||||
if self._close and self._sock is not None:
|
||||
self._sock.close()
|
||||
self._sock = None
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
# close() may fail if __init__ didn't complete
|
||||
pass
|
||||
|
||||
def fileno(self):
|
||||
return self._sock.fileno()
|
||||
|
||||
def gettimeout(self):
|
||||
return self._sock.gettimeout()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
def flush(self):
|
||||
if self._wbuf_len:
|
||||
data = b''.join(self._wbuf)
|
||||
self._wbuf = []
|
||||
self._wbuf_len = 0
|
||||
data_size = len(data)
|
||||
view = memoryview(data)
|
||||
write_offset = 0
|
||||
buffer_size = max(self._rbufsize, DEFAULT_BUFFER_SIZE)
|
||||
try:
|
||||
while write_offset < data_size:
|
||||
try:
|
||||
bytes_sent = self._sock.send(view[write_offset:write_offset+buffer_size])
|
||||
write_offset += bytes_sent
|
||||
self.bytes_written += bytes_sent
|
||||
except socket.error as e:
|
||||
if e.args[0] not in socket_errors_nonblocking:
|
||||
raise
|
||||
finally:
|
||||
if write_offset < data_size:
|
||||
remainder = data[write_offset:]
|
||||
self._wbuf.append(remainder)
|
||||
self._wbuf_len = len(remainder)
|
||||
del view, data # explicit free
|
||||
|
||||
def write(self, data):
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError('Cannot write data of type: %s to a socket' % type(data))
|
||||
if not data:
|
||||
return
|
||||
self._wbuf.append(data)
|
||||
self._wbuf_len += len(data)
|
||||
if self._wbufsize == 0 or (self._wbufsize == 1 and b'\n' in data) or (self._wbufsize > 1 and self._wbuf_len >= self._wbufsize):
|
||||
self.flush()
|
||||
|
||||
def writelines(self, lines):
|
||||
for line in lines:
|
||||
self.write(line)
|
||||
|
||||
def recv(self, size):
|
||||
while True:
|
||||
try:
|
||||
data = self._sock.recv(size)
|
||||
self.bytes_read += len(data)
|
||||
return data
|
||||
except socket.error, e:
|
||||
if e.args[0] not in socket_errors_nonblocking and e.args[0] not in socket_error_eintr:
|
||||
raise
|
||||
|
||||
def read(self, size=-1):
|
||||
# Use max, disallow tiny reads in a loop as they are very inefficient.
|
||||
# We never leave read() with any leftover data from a new recv() call
|
||||
# in our internal buffer.
|
||||
rbufsize = max(self._rbufsize, DEFAULT_BUFFER_SIZE)
|
||||
buf = self._rbuf
|
||||
buf.seek(0, os.SEEK_END)
|
||||
if size < 0:
|
||||
# Read until EOF
|
||||
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
|
||||
while True:
|
||||
data = self.recv(rbufsize)
|
||||
if not data:
|
||||
break
|
||||
buf.write(data)
|
||||
return buf.getvalue()
|
||||
else:
|
||||
# Read until size bytes or EOF seen, whichever comes first
|
||||
buf_len = buf.tell()
|
||||
if buf_len >= size:
|
||||
# Already have size bytes in our buffer? Extract and return.
|
||||
buf.seek(0)
|
||||
rv = buf.read(size)
|
||||
self._rbuf = BytesIO()
|
||||
self._rbuf.write(buf.read())
|
||||
return rv
|
||||
|
||||
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
|
||||
while True:
|
||||
left = size - buf_len
|
||||
# recv() will malloc the amount of memory given as its
|
||||
# parameter even though it often returns much less data
|
||||
# than that. The returned data string is short lived
|
||||
# as we copy it into a StringIO and free it. This avoids
|
||||
# fragmentation issues on many platforms.
|
||||
data = self.recv(left)
|
||||
if not data:
|
||||
break
|
||||
n = len(data)
|
||||
if n == size and not buf_len:
|
||||
# Shortcut. Avoid buffer data copies when:
|
||||
# - We have no data in our buffer.
|
||||
# AND
|
||||
# - Our call to recv returned exactly the
|
||||
# number of bytes we were asked to read.
|
||||
return data
|
||||
if n == left:
|
||||
buf.write(data)
|
||||
del data # explicit free
|
||||
break
|
||||
buf.write(data) # noqa
|
||||
buf_len += n
|
||||
del data # noqa explicit free
|
||||
return buf.getvalue()
|
||||
|
||||
def readline(self, size=-1, maxsize=sys.maxsize):
|
||||
buf = self._rbuf
|
||||
buf.seek(0, os.SEEK_END)
|
||||
if buf.tell() > 0:
|
||||
# check if we already have it in our buffer
|
||||
buf.seek(0)
|
||||
bline = buf.readline(size)
|
||||
self._rbuf = BytesIO()
|
||||
if bline.endswith(b'\n') or len(bline) == size:
|
||||
self._rbuf.write(buf.read())
|
||||
if len(bline) > maxsize:
|
||||
raise MaxSizeExceeded('Line length', len(bline), maxsize)
|
||||
return bline
|
||||
else:
|
||||
self._rbuf.write(bline)
|
||||
self._rbuf.write(buf.read())
|
||||
del bline
|
||||
|
||||
if size < 0:
|
||||
# Read until \n or EOF, whichever comes first
|
||||
if self._rbufsize <= 1:
|
||||
# Speed up unbuffered case
|
||||
buf.seek(0)
|
||||
buffers = [buf.read()]
|
||||
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
|
||||
data = None
|
||||
recv = self.recv
|
||||
sz = len(buffers[0])
|
||||
while data != b'\n':
|
||||
data = recv(1)
|
||||
if not data:
|
||||
break
|
||||
sz += 1
|
||||
if sz > maxsize:
|
||||
raise MaxSizeExceeded('Line length', sz, maxsize)
|
||||
buffers.append(data)
|
||||
return b''.join(buffers)
|
||||
|
||||
buf.seek(0, os.SEEK_END)
|
||||
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
|
||||
while True:
|
||||
data = self.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
nl = data.find(b'\n')
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
buf.write(data[:nl])
|
||||
self._rbuf.write(data[nl:])
|
||||
del data
|
||||
break
|
||||
buf.write(data) # noqa
|
||||
if buf.tell() > maxsize:
|
||||
raise MaxSizeExceeded('Line length', buf.tell(), maxsize)
|
||||
return buf.getvalue()
|
||||
else:
|
||||
# Read until size bytes or \n or EOF seen, whichever comes first
|
||||
buf.seek(0, os.SEEK_END)
|
||||
buf_len = buf.tell()
|
||||
if buf_len >= size:
|
||||
buf.seek(0)
|
||||
rv = buf.read(size)
|
||||
self._rbuf = BytesIO()
|
||||
self._rbuf.write(buf.read())
|
||||
if len(rv) > maxsize:
|
||||
raise MaxSizeExceeded('Line length', len(rv), maxsize)
|
||||
return rv
|
||||
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
|
||||
while True:
|
||||
data = self.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
left = size - buf_len
|
||||
# did we just receive a newline?
|
||||
nl = data.find(b'\n', 0, left)
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
# save the excess data to _rbuf
|
||||
self._rbuf.write(data[nl:])
|
||||
if buf_len:
|
||||
buf.write(data[:nl])
|
||||
break
|
||||
else:
|
||||
# Shortcut. Avoid data copy through buf when returning
|
||||
# a substring of our first recv() and buf has no
|
||||
# existing data.
|
||||
if nl > maxsize:
|
||||
raise MaxSizeExceeded('Line length', nl, maxsize)
|
||||
return data[:nl]
|
||||
n = len(data)
|
||||
if n == size and not buf_len:
|
||||
# Shortcut. Avoid data copy through buf when
|
||||
# returning exactly all of our first recv().
|
||||
if n > maxsize:
|
||||
raise MaxSizeExceeded('Line length', n, maxsize)
|
||||
return data
|
||||
if n >= left:
|
||||
buf.write(data[:left])
|
||||
self._rbuf.write(data[left:])
|
||||
break
|
||||
buf.write(data)
|
||||
buf_len += n
|
||||
if buf.tell() > maxsize:
|
||||
raise MaxSizeExceeded('Line length', buf.tell(), maxsize)
|
||||
return buf.getvalue()
|
||||
|
||||
def readlines(self, sizehint=0, maxsize=sys.maxsize):
|
||||
total = 0
|
||||
ans = []
|
||||
while True:
|
||||
line = self.readline(maxsize=maxsize)
|
||||
if not line:
|
||||
break
|
||||
ans.append(line)
|
||||
total += len(line)
|
||||
if sizehint and total >= sizehint:
|
||||
break
|
||||
return ans
|
||||
|
||||
def __iter__(self):
|
||||
line = True
|
||||
while line:
|
||||
line = self.readline()
|
||||
if line:
|
||||
yield line
|
||||
|
||||
# }}}
|
||||
|
||||
class Connection(object): # {{{
|
||||
|
||||
' A thin wrapper around an active socket '
|
||||
|
||||
remote_addr = None
|
||||
remote_port = None
|
||||
|
||||
def __init__(self, server_loop, socket):
|
||||
self.server_loop = server_loop
|
||||
self.socket = socket
|
||||
self.corked = Corked(socket)
|
||||
self.socket_file = SocketFile(socket)
|
||||
self.closed = False
|
||||
|
||||
def close(self):
|
||||
"""Close the socket underlying this connection."""
|
||||
if self.closed:
|
||||
return
|
||||
self.socket_file.close()
|
||||
self.ready = False
|
||||
self.handle_event = None # prevent reference cycles
|
||||
try:
|
||||
self.socket.shutdown(socket.SHUT_WR)
|
||||
self.socket.close()
|
||||
except socket.error:
|
||||
pass
|
||||
self.closed = True
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
# }}}
|
||||
|
||||
class WorkerThread(Thread): # {{{
|
||||
|
||||
daemon = True
|
||||
|
||||
def __init__(self, server_loop):
|
||||
self.serving = False
|
||||
self.server_loop = server_loop
|
||||
self.conn = None
|
||||
self.forcible_shutdown = False
|
||||
Thread.__init__(self, name='ServerWorker')
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
while True:
|
||||
self.serving = False
|
||||
self.conn = conn = self.server_loop.requests.get()
|
||||
if conn is None:
|
||||
return # Clean exit
|
||||
with conn, self:
|
||||
self.server_loop.req_resp_handler(conn)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
self.server_loop.stop()
|
||||
except socket.error:
|
||||
if not self.forcible_shutdown:
|
||||
raise
|
||||
|
||||
def __enter__(self):
|
||||
self.serving = True
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.serving = False
|
||||
# }}}
|
||||
|
||||
class ThreadPool(object): # {{{
|
||||
|
||||
def __init__(self, server_loop, min_threads=10, max_threads=-1, accepted_queue_size=-1, accepted_queue_timeout=10):
|
||||
self.server_loop = server_loop
|
||||
self.min_threads = max(1, min_threads)
|
||||
self.max_threads = max_threads
|
||||
self._threads = []
|
||||
self._queue = Queue(maxsize=accepted_queue_size)
|
||||
self._queue_put_timeout = accepted_queue_timeout
|
||||
self.get = self._queue.get
|
||||
|
||||
def start(self):
|
||||
"""Start the pool of threads."""
|
||||
self._threads = [self._spawn_worker() for i in xrange(self.min_threads)]
|
||||
|
||||
@property
|
||||
def idle(self):
|
||||
return sum(int(not w.serving) for w in self._threads)
|
||||
def state_description(self):
|
||||
return ''
|
||||
|
||||
@property
|
||||
def busy(self):
|
||||
return sum(int(w.serving) for w in self._threads)
|
||||
def report_unhandled_exception(self, e, formatted_traceback):
|
||||
pass
|
||||
|
||||
def put(self, obj):
|
||||
self._queue.put(obj, block=True, timeout=self._queue_put_timeout)
|
||||
def connection_ready(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def grow(self, amount):
|
||||
"""Spawn new worker threads (not above self.max_threads)."""
|
||||
budget = max(self.max_threads - len(self._threads), 0) if self.max_threads > 0 else sys.maxsize
|
||||
n_new = min(amount, budget)
|
||||
self._threads.extend([self._spawn_worker() for i in xrange(n_new)])
|
||||
|
||||
def _spawn_worker(self):
|
||||
worker = WorkerThread(self.server_loop)
|
||||
worker.start()
|
||||
return worker
|
||||
|
||||
@staticmethod
|
||||
def _all(func, items):
|
||||
results = [func(item) for item in items]
|
||||
return reduce(and_, results, True)
|
||||
|
||||
def shrink(self, amount):
|
||||
"""Kill off worker threads (not below self.min_threads)."""
|
||||
# Grow/shrink the pool if necessary.
|
||||
# Remove any dead threads from our list
|
||||
orig = len(self._threads)
|
||||
self._threads = [t for t in self._threads if t.is_alive()]
|
||||
amount -= orig - len(self._threads)
|
||||
|
||||
# calculate the number of threads above the minimum
|
||||
n_extra = max(len(self._threads) - self.min_threads, 0)
|
||||
|
||||
# don't remove more than amount
|
||||
n_to_remove = min(amount, n_extra)
|
||||
|
||||
# put shutdown requests on the queue equal to the number of threads
|
||||
# to remove. As each request is processed by a worker, that worker
|
||||
# will terminate and be culled from the list.
|
||||
for n in xrange(n_to_remove):
|
||||
self._queue.put(None)
|
||||
|
||||
def stop(self, timeout=5):
|
||||
# Must shut down threads here so the code that calls
|
||||
# this method can know when all threads are stopped.
|
||||
for worker in self._threads:
|
||||
self._queue.put(None)
|
||||
|
||||
# Don't join the current thread (this should never happen, since
|
||||
# ServerLoop calls stop() in its own thread, but better to be safe).
|
||||
current = current_thread()
|
||||
if timeout and timeout >= 0:
|
||||
endtime = time.time() + timeout
|
||||
while self._threads:
|
||||
worker = self._threads.pop()
|
||||
if worker is not current and worker.is_alive():
|
||||
try:
|
||||
if timeout is None or timeout < 0:
|
||||
worker.join()
|
||||
else:
|
||||
remaining_time = endtime - time.time()
|
||||
if remaining_time > 0:
|
||||
worker.join(remaining_time)
|
||||
if worker.is_alive():
|
||||
# We exhausted the timeout.
|
||||
# Forcibly shut down the socket.
|
||||
worker.forcible_shutdown = True
|
||||
c = worker.conn
|
||||
if c and not c.socket_file.closed:
|
||||
c.socket.shutdown(socket.SHUT_RDWR)
|
||||
c.socket.close()
|
||||
worker.join()
|
||||
except KeyboardInterrupt:
|
||||
pass # Ignore repeated Ctrl-C.
|
||||
|
||||
@property
|
||||
def qsize(self):
|
||||
return self._queue.qsize()
|
||||
# }}}
|
||||
def handle_timeout(self):
|
||||
return False
|
||||
|
||||
class ServerLoop(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
req_resp_handler,
|
||||
handler,
|
||||
bind_address=('localhost', 8080),
|
||||
opts=None,
|
||||
# A calibre logging object. If None a default log that logs to
|
||||
# A calibre logging object. If None, a default log that logs to
|
||||
# stdout is used
|
||||
log=None
|
||||
):
|
||||
self.ready = False
|
||||
self.handler = handler
|
||||
self.opts = opts or Options()
|
||||
self.req_resp_handler = req_resp_handler
|
||||
self.log = log or ThreadSafeLog(level=ThreadSafeLog.DEBUG)
|
||||
self.gso_cache, self.gso_lock = {}, Lock()
|
||||
ba = bind_address
|
||||
if not isinstance(ba, basestring):
|
||||
ba = tuple(ba)
|
||||
if not ba[0]:
|
||||
# AI_PASSIVE does not work with host of '' or None
|
||||
ba = ('0.0.0.0', ba[1])
|
||||
|
||||
ba = tuple(bind_address)
|
||||
if not ba[0]:
|
||||
# AI_PASSIVE does not work with host of '' or None
|
||||
ba = ('0.0.0.0', ba[1])
|
||||
self.bind_address = ba
|
||||
self.bound_address = None
|
||||
self.connection_map = {}
|
||||
|
||||
self.ssl_context = None
|
||||
if self.opts.ssl_certfile is not None and self.opts.ssl_keyfile is not None:
|
||||
self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
@ -523,51 +145,33 @@ class ServerLoop(object):
|
||||
set_socket_inherit(self.pre_activated_socket, False)
|
||||
self.bind_address = self.pre_activated_socket.getsockname()
|
||||
|
||||
self.ready = False
|
||||
self.requests = ThreadPool(self, min_threads=self.opts.min_threads, max_threads=self.opts.max_threads)
|
||||
|
||||
def __str__(self):
|
||||
return "%s(%r)" % (self.__class__.__name__, self.bind_address)
|
||||
__repr__ = __str__
|
||||
|
||||
@property
|
||||
def num_active_connections(self):
|
||||
return len(self.connection_map)
|
||||
|
||||
def serve_forever(self):
|
||||
""" Listen for incoming connections. """
|
||||
|
||||
if self.pre_activated_socket is None:
|
||||
# Select the appropriate socket
|
||||
if isinstance(self.bind_address, basestring):
|
||||
# AF_UNIX socket
|
||||
|
||||
# So we can reuse the socket...
|
||||
try:
|
||||
os.unlink(self.bind_address)
|
||||
except EnvironmentError:
|
||||
pass
|
||||
|
||||
# So everyone can access the socket...
|
||||
try:
|
||||
os.chmod(self.bind_address, 0777)
|
||||
except EnvironmentError:
|
||||
pass
|
||||
|
||||
info = [
|
||||
(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_address)]
|
||||
else:
|
||||
# AF_INET or AF_INET6 socket
|
||||
# Get the correct address family for our host (allows IPv6
|
||||
# addresses)
|
||||
host, port = self.bind_address
|
||||
try:
|
||||
info = socket.getaddrinfo(
|
||||
host, port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
|
||||
except socket.gaierror:
|
||||
if ':' in host:
|
||||
info = [(socket.AF_INET6, socket.SOCK_STREAM,
|
||||
0, "", self.bind_address + (0, 0))]
|
||||
else:
|
||||
info = [(socket.AF_INET, socket.SOCK_STREAM,
|
||||
0, "", self.bind_address)]
|
||||
# AF_INET or AF_INET6 socket
|
||||
# Get the correct address family for our host (allows IPv6
|
||||
# addresses)
|
||||
host, port = self.bind_address
|
||||
try:
|
||||
info = socket.getaddrinfo(
|
||||
host, port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
|
||||
except socket.gaierror:
|
||||
if ':' in host:
|
||||
info = [(socket.AF_INET6, socket.SOCK_STREAM,
|
||||
0, "", self.bind_address + (0, 0))]
|
||||
else:
|
||||
info = [(socket.AF_INET, socket.SOCK_STREAM,
|
||||
0, "", self.bind_address)]
|
||||
|
||||
self.socket = None
|
||||
msg = "No socket could be created"
|
||||
@ -589,28 +193,28 @@ class ServerLoop(object):
|
||||
self.pre_activated_socket = None
|
||||
self.setup_socket()
|
||||
|
||||
self.socket.listen(5)
|
||||
self.connection_map = {}
|
||||
self.socket.listen(min(socket.SOMAXCONN, 128))
|
||||
self.bound_address = ba = self.socket.getsockname()
|
||||
if isinstance(ba, tuple):
|
||||
ba = ':'.join(map(type(''), ba))
|
||||
self.log('calibre server listening on', ba)
|
||||
with TemporaryDirectory(prefix='srv-') as tdir:
|
||||
self.tdir = tdir
|
||||
self.ready = True
|
||||
self.log('calibre server listening on', ba)
|
||||
|
||||
# Create worker threads
|
||||
self.requests.start()
|
||||
self.ready = True
|
||||
|
||||
while self.ready:
|
||||
try:
|
||||
self.tick()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
raise
|
||||
except:
|
||||
self.log.exception('Error in ServerLoop.tick')
|
||||
while True:
|
||||
try:
|
||||
self.tick()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
self.shutdown()
|
||||
break
|
||||
except:
|
||||
self.log.exception('Error in ServerLoop.tick')
|
||||
|
||||
def setup_socket(self):
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
if self.opts.no_delay and not isinstance(self.bind_address, basestring):
|
||||
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
|
||||
# If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
|
||||
# activate dual-stack.
|
||||
@ -623,149 +227,159 @@ class ServerLoop(object):
|
||||
# Apparently, the socket option is not available in
|
||||
# this machine's TCP stack
|
||||
pass
|
||||
self.socket.setblocking(0)
|
||||
|
||||
def bind(self, family, atype, proto=0):
|
||||
"""Create (or recreate) the actual socket object."""
|
||||
'''Create (or recreate) the actual socket object.'''
|
||||
self.socket = socket.socket(family, atype, proto)
|
||||
set_socket_inherit(self.socket, False)
|
||||
self.setup_socket()
|
||||
self.socket.bind(self.bind_address)
|
||||
|
||||
def tick(self):
|
||||
"""Accept a new connection and put it on the Queue."""
|
||||
now = monotonic()
|
||||
for s, conn in tuple(self.connection_map.iteritems()):
|
||||
if now - conn.last_activity > self.opts.timeout:
|
||||
if not conn.handle_timeout():
|
||||
self.log.debug('Closing connection because of extended inactivity')
|
||||
self.close(s, conn)
|
||||
|
||||
read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR]
|
||||
write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR]
|
||||
try:
|
||||
s, addr = self.socket.accept()
|
||||
if not self.ready:
|
||||
readable, writable, _ = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout)
|
||||
except select.error as e:
|
||||
if e.errno in socket_errors_eintr:
|
||||
return
|
||||
|
||||
set_socket_inherit(s, False)
|
||||
if hasattr(s, 'settimeout'):
|
||||
s.settimeout(self.opts.timeout)
|
||||
|
||||
if self.ssl_context is not None:
|
||||
for s, conn in tuple(self.connection_map.iteritems()):
|
||||
try:
|
||||
s = self.ssl_context.wrap_socket(s, server_side=True)
|
||||
except ssl.SSLEOFError:
|
||||
return # Ignore, client closed connection
|
||||
except ssl.SSLError as e:
|
||||
if e.args[1].endswith('http request'):
|
||||
msg = (b"The client sent a plain HTTP request, but "
|
||||
b"this server only speaks HTTPS on this port.")
|
||||
response = [
|
||||
b"HTTP/1.1 400 Bad Request\r\n",
|
||||
str("Content-Length: %s\r\n" % len(msg)),
|
||||
b"Content-Type: text/plain\r\n\r\n",
|
||||
msg
|
||||
]
|
||||
with SocketFile(s._sock) as f:
|
||||
f.write(response)
|
||||
return
|
||||
elif e.args[1].endswith('unknown protocol'):
|
||||
return # Drop connection
|
||||
raise
|
||||
if hasattr(s, 'settimeout'):
|
||||
s.settimeout(self.opts.timeout)
|
||||
select.select([s], [], [], 0)
|
||||
except select.error:
|
||||
self.close(s, conn) # Bad socket, discard
|
||||
|
||||
conn = Connection(self, s)
|
||||
|
||||
if not isinstance(self.bind_address, basestring):
|
||||
# optional values
|
||||
# Until we do DNS lookups, omit REMOTE_HOST
|
||||
if addr is None: # sometimes this can happen
|
||||
# figure out if AF_INET or AF_INET6.
|
||||
if len(s.getsockname()) == 2:
|
||||
# AF_INET
|
||||
addr = ('0.0.0.0', 0)
|
||||
else:
|
||||
# AF_INET6
|
||||
addr = ('::', 0)
|
||||
conn.remote_addr = addr[0]
|
||||
conn.remote_port = addr[1]
|
||||
|
||||
try:
|
||||
self.requests.put(conn)
|
||||
except Full:
|
||||
self.log.warn('Server overloaded, dropping connection')
|
||||
conn.close()
|
||||
return
|
||||
except socket.timeout:
|
||||
# The only reason for the timeout in start() is so we can
|
||||
# notice keyboard interrupts on Win32, which don't interrupt
|
||||
# accept() by default
|
||||
return
|
||||
except socket.error as e:
|
||||
if e.args[0] in socket_error_eintr | socket_errors_nonblocking | socket_errors_to_ignore:
|
||||
return
|
||||
raise
|
||||
|
||||
def stop(self):
|
||||
""" Gracefully shutdown the server loop. """
|
||||
if not self.ready:
|
||||
return
|
||||
# We run the stop code in its own thread so that it is not interrupted
|
||||
# by KeyboardInterrupt
|
||||
self.ready = False
|
||||
t = Thread(target=self._stop)
|
||||
t.start()
|
||||
try:
|
||||
t.join()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
def _stop(self):
|
||||
self.log('Shutting down server gracefully, waiting for connections to close...')
|
||||
self.requests.stop(self.opts.shutdown_timeout)
|
||||
sock = self.tick_once()
|
||||
if hasattr(sock, "close"):
|
||||
sock.close()
|
||||
self.socket = None
|
||||
ignore = set()
|
||||
for s, conn, event in self.get_actions(readable, writable):
|
||||
if s in ignore:
|
||||
continue
|
||||
try:
|
||||
conn.handle_event(event)
|
||||
if not conn.ready:
|
||||
self.close(s, conn)
|
||||
except Exception as e:
|
||||
ignore.add(s)
|
||||
if conn.ready:
|
||||
self.log.exception('Unhandled exception in state: %s' % conn.state_description)
|
||||
if conn.response_started:
|
||||
self.close(s, conn)
|
||||
else:
|
||||
try:
|
||||
conn.report_unhandled_exception(e, traceback.format_exc())
|
||||
except Exception:
|
||||
self.close(s, conn)
|
||||
else:
|
||||
self.log.error('Error in SSL handshake, terminating connection: %s' % as_unicode(e))
|
||||
self.close(s, conn)
|
||||
|
||||
def tick_once(self):
|
||||
# Touch our own socket to make accept() return immediately.
|
||||
def wakeup(self):
|
||||
# Touch our own socket to make select() return immediately.
|
||||
sock = getattr(self, "socket", None)
|
||||
if sock is not None:
|
||||
if not isinstance(self.bind_address, basestring):
|
||||
try:
|
||||
host, port = sock.getsockname()[:2]
|
||||
except socket.error as e:
|
||||
if e.args[0] not in socket_errors_to_ignore:
|
||||
raise
|
||||
else:
|
||||
# Ensure tick() returns by opening a transient connection
|
||||
# to our own listening socket
|
||||
for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM):
|
||||
af, socktype, proto, canonname, sa = res
|
||||
s = None
|
||||
try:
|
||||
s = socket.socket(af, socktype, proto)
|
||||
s.settimeout(1.0)
|
||||
s.connect((host, port))
|
||||
try:
|
||||
host, port = sock.getsockname()[:2]
|
||||
except socket.error as e:
|
||||
if e.errno not in socket_errors_socket_closed:
|
||||
raise
|
||||
else:
|
||||
for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM):
|
||||
af, socktype, proto, canonname, sa = res
|
||||
s = None
|
||||
try:
|
||||
s = socket.socket(af, socktype, proto)
|
||||
s.settimeout(1.0)
|
||||
s.connect((host, port))
|
||||
s.close()
|
||||
except socket.error:
|
||||
if s is not None:
|
||||
s.close()
|
||||
except socket.error:
|
||||
if s is not None:
|
||||
s.close()
|
||||
return sock
|
||||
|
||||
def echo_handler(conn):
|
||||
keep_going = True
|
||||
while keep_going:
|
||||
def close(self, s, conn):
    """Forget the socket *s* and shut down its associated connection."""
    try:
        del self.connection_map[s]
    except KeyError:
        pass  # Already removed, nothing to do
    conn.close()
|
||||
|
||||
def get_actions(self, readable, writable):
    """Yield (socket, connection, event) tuples for all ready sockets.

    The listening socket is special-cased: readiness on it means a new
    client, which is accepted and registered in the connection map. A
    brand-new TLS connection is reported as RDWR (it must do the
    handshake first); other readable sockets are reported as READ and
    writable sockets as WRITE.
    """
    for sock in readable:
        if sock is not self.socket:
            yield sock, self.connection_map[sock], READ
            continue
        # Readiness on the listener means a new incoming connection
        sock, addr = self.accept()
        if sock is None:
            continue
        conn = self.handler(sock, self.opts, self.ssl_context, self.tdir)
        self.connection_map[sock] = conn
        if self.ssl_context is not None:
            yield sock, conn, RDWR
    for sock in writable:
        yield sock, self.connection_map[sock], WRITE
|
||||
|
||||
def accept(self):
|
||||
try:
|
||||
line = conn.socket_file.readline()
|
||||
except socket.timeout:
|
||||
continue
|
||||
conn.server_loop.log('Received:', repr(line))
|
||||
if not line.rstrip():
|
||||
keep_going = False
|
||||
line = b'bye\r\n'
|
||||
conn.socket_file.write(line)
|
||||
conn.socket_file.flush()
|
||||
return self.socket.accept()
|
||||
except socket.error:
|
||||
return None, None
|
||||
|
||||
def stop(self):
    # Signal the event loop to exit: clear the ready flag, then wake the
    # loop so it notices immediately instead of waiting for the next event.
    self.ready = False
    self.wakeup()
|
||||
|
||||
def shutdown(self):
    """Forcibly shut the server down: close the listening socket, then
    close every tracked client connection."""
    try:
        listener = getattr(self, 'socket', None)
        if listener:
            listener.close()
            self.socket = None
    except socket.error:
        pass  # Listener already dead, ignore
    # Snapshot the map first: self.close() mutates it while we iterate
    for sock, conn in tuple(self.connection_map.iteritems()):
        self.close(sock, conn)
|
||||
|
||||
class EchoLine(Connection):  # {{{
    """Toy connection handler: echoes every received line back to the
    client. An (almost) empty line gets 'bye' prepended to its echo and
    causes the connection to close afterwards."""

    bye_after_echo = False

    def connection_ready(self):
        # Start with an empty accumulation buffer, waiting for input
        self.rbuf = BytesIO()
        self.set_state(READ, self.read_line)

    def read_line(self, event):
        byte = self.recv(1)
        if not byte:
            return
        self.rbuf.write(byte)
        if byte != b'\n':
            return
        # A complete line has been accumulated
        if self.rbuf.tell() < 3:
            # Empty line (just the terminator): reply with bye and quit
            self.rbuf = BytesIO(b'bye' + self.rbuf.getvalue())
            self.bye_after_echo = True
        self.rbuf.seek(0)
        self.set_state(WRITE, self.echo)

    def echo(self, event):
        pos = self.rbuf.tell()
        self.rbuf.seek(0, os.SEEK_END)
        remaining = self.rbuf.tell() - pos
        self.rbuf.seek(pos)
        sent = self.send(self.rbuf.read(512))
        if sent != remaining:
            # Partial write: continue from the new position next time
            self.rbuf.seek(pos + sent)
            return
        # Whole buffer flushed; go back to reading the next line
        self.rbuf = BytesIO()
        self.set_state(READ, self.read_line)
        if self.bye_after_echo:
            self.ready = False
    # }}}
|
||||
|
||||
if __name__ == '__main__':
|
||||
s = ServerLoop(echo_handler)
|
||||
with HandleInterrupt(s.tick_once):
|
||||
try:
|
||||
s.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
s.stop()
|
||||
s = ServerLoop(EchoLine)
|
||||
with HandleInterrupt(s.wakeup):
|
||||
s.serve_forever()
|
||||
|
@ -36,14 +36,6 @@ raw_options = (
|
||||
'shutdown_timeout', 5.0,
|
||||
None,
|
||||
|
||||
'Minimum number of connection handling threads',
|
||||
'min_threads', 10,
|
||||
None,
|
||||
|
||||
'Maximum number of simultaneous connections (beyond this number of connections will be dropped)',
|
||||
'max_threads', 500,
|
||||
None,
|
||||
|
||||
'Allow socket pre-allocation, for example, with systemd socket activation',
|
||||
'allow_socket_preallocation', True,
|
||||
None,
|
||||
@ -52,7 +44,7 @@ raw_options = (
|
||||
'max_header_line_size', 8.0,
|
||||
None,
|
||||
|
||||
'Max. size of a HTTP request (in MB)',
|
||||
'Max. allowed size for files uploaded to the server (in MB)',
|
||||
'max_request_body_size', 500.0,
|
||||
None,
|
||||
|
||||
@ -60,12 +52,6 @@ raw_options = (
|
||||
'compress_min_size', 1024,
|
||||
None,
|
||||
|
||||
'Decrease latency by using the TCP_NODELAY feature',
|
||||
'no_delay', True,
|
||||
'no_delay turns on TCP_NODELAY which decreases latency at the cost of'
|
||||
' worse overall performance when sending multiple small packets. It'
|
||||
' prevents the TCP stack from aggregating multiple small TCP packets.',
|
||||
|
||||
'Use zero copy file transfers for increased performance',
|
||||
'use_sendfile', True,
|
||||
'This will use zero-copy in-kernel transfers when sending files over the network,'
|
||||
|
@ -1,356 +0,0 @@
|
||||
#!/usr/bin/env python2
|
||||
# vim:fileencoding=utf-8
|
||||
from __future__ import (unicode_literals, division, absolute_import,
|
||||
print_function)
|
||||
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import os, hashlib, httplib, zlib, struct, time, uuid
|
||||
from io import DEFAULT_BUFFER_SIZE, BytesIO
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from future_builtins import map
|
||||
from itertools import izip_longest
|
||||
|
||||
from calibre import force_unicode, guess_type
|
||||
from calibre.srv.errors import IfNoneMatch, RangeNotSatisfiable
|
||||
from calibre.srv.sendfile import file_metadata, copy_range, sendfile_to_socket
|
||||
|
||||
Range = namedtuple('Range', 'start stop size')
|
||||
MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii')
|
||||
|
||||
def get_ranges(headervalue, content_length):
    ''' Parse an HTTP Range header value into a list of Range objects.

    Returns None when the header is absent or malformed at the top level,
    and an empty list when no individual range was satisfiable. '''
    if not headervalue:
        return None

    try:
        unit, rangespec = headervalue.split("=", 1)
    except Exception:
        return None
    if unit.strip() != 'bytes':
        return None

    ans = []
    for spec in rangespec.split(","):
        begin, end = [part.strip() for part in spec.split("-", 1)]
        if begin:
            # Normal "start-stop" form; a missing stop means to EOF
            if not end:
                end = content_length - 1
            try:
                begin, end = int(begin), int(end)
            except Exception:
                continue
            if begin >= content_length or end < begin:
                continue
            end = min(end, content_length - 1)
            ans.append(Range(begin, end, end - begin + 1))
        elif end:
            # Suffix form "-N": the last N bytes
            try:
                end = int(end)
            except Exception:
                continue
            if end > content_length:
                ans.append(Range(0, content_length - 1, content_length))
            else:
                ans.append(Range(content_length - end, content_length - 1, end))

    return ans
|
||||
|
||||
|
||||
def acceptable_encoding(val, allowed=frozenset({'gzip'})):
    """Choose the client's most preferred content encoding from an
    Accept-Encoding header value, restricted to *allowed*. Returns None
    when nothing acceptable was offered."""
    def parse_one(item):
        # "gzip;q=0.8" -> ('gzip', 0.8); the quality defaults to 1.0
        encoding, _, params = item.partition(';')
        key, _, qval = params.partition('=')
        quality = 1.0
        if key == 'q' and qval:
            try:
                quality = float(qval)
            except Exception:
                pass
        return encoding.lower(), quality

    emap = {}
    for item in val.split(','):
        name, quality = parse_one(item.strip())
        emap[name] = quality
    candidates = sorted(set(emap) & allowed, key=emap.__getitem__, reverse=True)
    if candidates:
        return candidates[0]
|
||||
|
||||
def gzip_prefix(mtime):
    """Return the 10-byte gzip member header (RFC 1952) stamped with *mtime*."""
    # Fields: ID1+ID2 magic, CM=8 (deflate), FLG=0 (no extras),
    # MTIME as little-endian 32-bit, XFL=2 (max compression), OS=0xff (unknown)
    parts = [b'\x1f\x8b', b'\x08', b'\x00']
    parts.append(struct.pack(b"<L", int(mtime) & 0xFFFFFFFF))
    parts.append(b'\x02')
    parts.append(b'\xff')
    return b''.join(parts)
|
||||
|
||||
def write_chunked_data(dest, data):
    """Write *data* to *dest* as one HTTP/1.1 chunk: hexadecimal size line,
    payload, trailing CRLF. Empty *data* produces the terminating chunk."""
    size_line = '%X\r\n' % len(data)
    dest.write(size_line.encode('ascii'))
    dest.write(data)
    dest.write(b'\r\n')
|
||||
|
||||
def write_compressed_file_obj(input_file, dest, compress_level=6):
    """Stream *input_file* to *dest* as chunked, gzip-compressed data.

    The gzip member header is emitted together with the first data chunk
    and the CRC32/size trailer with the last, so the concatenation of all
    chunk payloads forms a valid gzip stream.
    """
    compressor = zlib.compressobj(compress_level,
                        zlib.DEFLATED, -zlib.MAX_WBITS,
                        zlib.DEF_MEM_LEVEL, 0)
    crc = zlib.crc32(b"")
    total = 0
    header_pending = True
    while True:
        raw = input_file.read(DEFAULT_BUFFER_SIZE)
        if not raw:
            break
        total += len(raw)
        crc = zlib.crc32(raw, crc)
        compressed = compressor.compress(raw)
        if header_pending:
            header_pending = False
            compressed = gzip_prefix(time.time()) + compressed
        write_chunked_data(dest, compressed)
    # Flush the compressor and append the gzip trailer (CRC32 + ISIZE)
    trailer = compressor.flush() + struct.pack(b"<L", crc & 0xFFFFFFFF) + struct.pack(b"<L", total & 0xFFFFFFFF)
    write_chunked_data(dest, trailer)
    write_chunked_data(dest, b'')
|
||||
|
||||
def get_range_parts(ranges, content_type, content_length):
    """Build the per-range header blocks for a multipart/byteranges
    response body: one encoded block per range, plus the final
    terminating boundary marker."""

    def headers_for(r):
        lines = ['--%s' % MULTIPART_SEPARATOR,
                 'Content-Range: bytes %d-%d/%d' % (r.start, r.stop, content_length)]
        if content_type:
            lines.append('Content-Type: %s' % content_type)
        lines.append('')
        return ('\r\n'.join(lines)).encode('ascii')

    parts = [headers_for(r) for r in ranges]
    parts.append(('--%s--' % MULTIPART_SEPARATOR).encode('ascii'))
    return parts
|
||||
|
||||
def parse_multipart_byterange(buf, content_type):
    """Parse a multipart/byteranges message body into its sub-parts.

    :param buf: File-like object positioned at the start of the body.
    :param content_type: The Content-Type header value; the boundary
        separator is taken from the text after its last ``=``.
    :return: List of ``(start_offset, bytes)`` tuples, one per sub-part.
    :raises ValueError: If the message is truncated or malformed.
    """
    from calibre.srv.http import read_headers
    # Boundary separator: everything after the last '=' in the header value
    sep = (content_type.rsplit('=', 1)[-1]).encode('utf-8')
    ans = []

    def parse_part():
        # Every part begins with a '--<boundary>' line; the terminating
        # line of the whole message is '--<boundary>--'
        line = buf.readline()
        if not line:
            raise ValueError('Premature end of message')
        if not line.startswith(b'--' + sep):
            raise ValueError('Malformed start of multipart message')
        # NOTE(review): assumes the closing boundary line has no trailing
        # CRLF at this point -- confirm against the producer of the body
        if line.endswith(b'--'):
            return None
        headers = read_headers(buf.readline)
        cr = headers.get('Content-Range')
        if not cr:
            raise ValueError('Missing Content-Range header in sub-part')
        if not cr.startswith('bytes '):
            raise ValueError('Malformed Content-Range header in sub-part, no prefix')
        try:
            # 'bytes start-stop/total' -> (start, stop)
            start, stop = map(lambda x: int(x.strip()), cr.partition(' ')[-1].partition('/')[0].partition('-')[::2])
        except Exception:
            raise ValueError('Malformed Content-Range header in sub-part, failed to parse byte range')
        content_length = stop - start + 1
        ret = buf.read(content_length)
        if len(ret) != content_length:
            raise ValueError('Malformed sub-part, length of body not equal to length specified in Content-Range')
        buf.readline()  # Consume the CRLF that follows the part body
        return (start, ret)
    while True:
        data = parse_part()
        if data is None:
            break
        ans.append(data)
    return ans
|
||||
|
||||
class ReadableOutput(object):
    """Wraps a seekable file-like object for use as a response body,
    supporting whole-file, gzip-compressed and byte-range transmission."""

    def __init__(self, output, outheaders):
        self.src_file = output
        # Discover the total size by seeking to the end
        self.src_file.seek(0, os.SEEK_END)
        self.content_length = self.src_file.tell()
        self.etag = None
        self.accept_ranges = True
        self.use_sendfile = False

    def write(self, dest):
        """Send the entire file to *dest*."""
        if self.use_sendfile:
            # Ensure everything in the SocketFile buffer is sent before calling sendfile()
            dest.flush()
            sent = sendfile_to_socket(self.src_file, 0, self.content_length, dest)
        else:
            sent = copy_range(self.src_file, 0, self.content_length, dest)
        if sent != self.content_length:
            raise IOError(
                'Failed to send complete file (%r) (%s != %s bytes), perhaps the file was modified during send?' % (
                    getattr(self.src_file, 'name', '<file>'), sent, self.content_length))
        self.src_file = None

    def write_compressed(self, dest):
        """Send the entire file gzip-compressed, as chunked data."""
        self.src_file.seek(0)
        write_compressed_file_obj(self.src_file, dest)
        self.src_file = None

    def write_ranges(self, ranges, dest):
        """Send a single Range, or an iterable of (Range, header) pairs for
        a multipart/byteranges body, to *dest*."""
        if isinstance(ranges, Range):
            self.copy_range(ranges.start, ranges.size, dest)
        else:
            for rng, header in ranges:
                dest.write(header)
                if rng is not None:
                    dest.write(b'\r\n')
                    self.copy_range(rng.start, rng.size, dest)
                    dest.write(b'\r\n')
        self.src_file = None

    def copy_range(self, start, size, dest):
        """Send *size* bytes of the file starting at *start* to *dest*."""
        if self.use_sendfile:
            # Ensure everything in the SocketFile buffer is sent before calling sendfile()
            dest.flush()
            sent = sendfile_to_socket(self.src_file, start, size, dest)
        else:
            # Note: this resolves to the module-level copy_range() helper
            sent = copy_range(self.src_file, start, size, dest)
        if sent != size:
            raise IOError('Failed to send byte range from file (%r) (%s != %s bytes), perhaps the file was modified during send?' % (
                getattr(self.src_file, 'name', '<file>'), sent, size))
|
||||
|
||||
class FileSystemOutputFile(ReadableOutput):
    """A ReadableOutput backed by a real file on disk: size and ETag come
    from the stat() result, and sendfile() may be used when available."""

    def __init__(self, output, outheaders, stat_result, use_sendfile):
        self.src_file = output
        self.name = output.name
        self.accept_ranges = True
        self.content_length = stat_result.st_size
        # sendfile() is only usable when the platform implementation exists
        self.use_sendfile = use_sendfile and sendfile_to_socket is not None
        # The ETag derives from mtime + path, so it changes whenever the
        # file is modified or replaced
        self.etag = '"%s"' % hashlib.sha1(type('')(stat_result.st_mtime) + force_unicode(output.name or '')).hexdigest()
|
||||
|
||||
|
||||
class DynamicOutput(object):
    """An in-memory response body. Text input is encoded as UTF-8 and a
    default text/plain Content-Type is supplied when none is set."""

    def __init__(self, output, outheaders):
        self.data = output if isinstance(output, bytes) else output.encode('utf-8')
        if not outheaders.get('Content-Type'):
            outheaders.set('Content-Type', 'text/plain; charset=UTF-8', replace_all=True)
        self.content_length = len(self.data)
        self.etag = None
        self.accept_ranges = False

    def write(self, dest):
        """Send the body to *dest* and release it."""
        dest.write(self.data)
        self.data = None

    def write_compressed(self, dest):
        """Send the body gzip-compressed, as chunked data."""
        write_compressed_file_obj(BytesIO(self.data), dest)
|
||||
|
||||
class GeneratedOutput(object):
    """A response body produced lazily by an iterable; it is sent with
    chunked transfer encoding since the total size is unknown up front."""

    def __init__(self, output, outheaders):
        self.output = output
        # Neither total length nor an ETag can be known for a generator
        self.content_length = self.etag = None
        self.accept_ranges = False

    def write(self, dest):
        """Stream each produced chunk to *dest*."""
        for chunk in self.output:
            # Skip falsy chunks: an empty chunk would terminate the stream
            if chunk:
                write_chunked_data(dest, chunk)
|
||||
|
||||
class StaticGeneratedOutput(object):
    """A generated-once, cacheable response body: raw bytes plus a
    precomputed strong ETag."""

    def __init__(self, data):
        # Normalize to bytes; text is stored UTF-8 encoded
        if isinstance(data, type('')):
            data = data.encode('utf-8')
        self.data = data
        self.content_length = len(data)
        self.accept_ranges = False
        self.etag = '"%s"' % hashlib.sha1(data).hexdigest()

    def write(self, dest):
        dest.write(self.data)

    def write_compressed(self, dest):
        write_compressed_file_obj(BytesIO(self.data), dest)
|
||||
|
||||
def generate_static_output(cache, gso_lock, name, generator):
    """Return the cached StaticGeneratedOutput for *name*, running
    *generator* (under *gso_lock*) to create and cache it on first use."""
    with gso_lock:
        ans = cache.get(name)
        if ans is None:
            ans = StaticGeneratedOutput(generator())
            cache[name] = ans
        return ans
|
||||
|
||||
def parse_if_none_match(val):
    """Split an If-None-Match header value into its set of entity tags."""
    return {tag.strip() for tag in val.split(',')}
|
||||
|
||||
def finalize_output(output, inheaders, outheaders, status_code, is_http1, method, opts):
    """Convert a request handler's return value into a concrete output
    object and fill in the response headers.

    :param output: What the handler returned: a file object backed by the
        filesystem, (byte)string, readable object, StaticGeneratedOutput,
        or an arbitrary iterable of chunks.
    :param inheaders: The parsed request headers.
    :param outheaders: The response headers object (mutated in place).
    :param status_code: The proposed HTTP status code.
    :param is_http1: True for HTTP/1.0 clients, which get no compression,
        ranges or chunked encoding.
    :param method: The HTTP request method.
    :param opts: The server Options object.
    :return: ``(status_code, output)`` where output.commit is the callable
        that writes the body.
    :raises IfNoneMatch: When the client's cached entity is still valid.
    :raises RangeNotSatisfiable: When a Range header matches no bytes.
    """
    ct = outheaders.get('Content-Type', '')
    compressible = not ct or ct.startswith('text/') or ct.startswith('image/svg') or ct.startswith('application/json')

    # Wrap the raw handler output in the appropriate *Output object
    stat_result = file_metadata(output)
    if stat_result is not None:
        output = FileSystemOutputFile(output, outheaders, stat_result, opts.use_sendfile)
        if 'Content-Type' not in outheaders:
            mt = guess_type(output.name)[0]
            if mt:
                if mt in ('text/plain', 'text/html'):
                    # Fixed: was "mt =+ '; charset=UTF-8'", which assigned the
                    # literal and discarded the guessed mimetype
                    mt += '; charset=UTF-8'
                outheaders['Content-Type'] = mt
    elif isinstance(output, (bytes, type(''))):
        output = DynamicOutput(output, outheaders)
    elif hasattr(output, 'read'):
        output = ReadableOutput(output, outheaders)
    elif isinstance(output, StaticGeneratedOutput):
        pass
    else:
        output = GeneratedOutput(output, outheaders)

    compressible = (status_code == httplib.OK and compressible and
                    (opts.compress_min_size > -1 and output.content_length >= opts.compress_min_size) and
                    acceptable_encoding(inheaders.get('Accept-Encoding', '')) and not is_http1)
    # NOTE(review): 'is not None' is always true here since the wrappers set
    # accept_ranges to True/False -- possibly meant as a truthiness check; confirm
    accept_ranges = (not compressible and output.accept_ranges is not None and status_code == httplib.OK and
                     not is_http1)
    ranges = get_ranges(inheaders.get('Range'), output.content_length) if output.accept_ranges and method in ('GET', 'HEAD') else None
    if_range = (inheaders.get('If-Range') or '').strip()
    if if_range and if_range != output.etag:
        ranges = None
    if ranges is not None and not ranges:
        raise RangeNotSatisfiable(output.content_length)

    # Start from a clean slate for the headers computed below
    for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'):
        # Fixed: was pop('header', ...), which popped the literal string
        # 'header' and so never removed any stale values
        outheaders.pop(header, all=True)

    none_match = parse_if_none_match(inheaders.get('If-None-Match', ''))
    matched = '*' in none_match or (output.etag and output.etag in none_match)
    if matched:
        raise IfNoneMatch(output.etag)

    if output.etag and method in ('GET', 'HEAD'):
        outheaders.set('ETag', output.etag, replace_all=True)
    if accept_ranges:
        outheaders.set('Accept-Ranges', 'bytes', replace_all=True)
    elif compressible:
        outheaders.set('Content-Encoding', 'gzip', replace_all=True)
    if output.content_length is not None and not compressible and not ranges:
        outheaders.set('Content-Length', '%d' % output.content_length, replace_all=True)

    if compressible or output.content_length is None:
        outheaders.set('Transfer-Encoding', 'chunked', replace_all=True)

    if ranges:
        if len(ranges) == 1:
            # Single range: plain body with a Content-Range header
            r = ranges[0]
            outheaders.set('Content-Length', '%d' % r.size, replace_all=True)
            outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True)
            output.commit = partial(output.write_ranges, r)
        else:
            # Multiple ranges: multipart/byteranges body; each part body is
            # wrapped in CRLF pairs, hence the +4 bytes per range
            range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length)
            size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges)
            outheaders.set('Content-Length', '%d' % size, replace_all=True)
            outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True)
            output.commit = partial(output.write_ranges, izip_longest(ranges, range_parts))
        status_code = httplib.PARTIAL_CONTENT
    else:
        output.commit = output.write_compressed if compressible else output.write

    return status_code, output
|
@ -10,7 +10,7 @@ import os, ctypes, errno, socket
|
||||
from io import DEFAULT_BUFFER_SIZE
|
||||
from select import select
|
||||
|
||||
from calibre.constants import iswindows, isosx
|
||||
from calibre.constants import islinux, isosx
|
||||
from calibre.srv.utils import eintr_retry_call
|
||||
|
||||
def file_metadata(fileobj):
|
||||
@ -33,10 +33,15 @@ def copy_range(src_file, start, size, dest):
|
||||
del data
|
||||
return total_sent
|
||||
|
||||
class CannotSendfile(Exception):
|
||||
pass
|
||||
|
||||
if iswindows:
|
||||
sendfile_to_socket = None
|
||||
elif isosx:
|
||||
class SendfileInterrupted(Exception):
|
||||
pass
|
||||
|
||||
sendfile_to_socket = sendfile_to_socket_async = None
|
||||
|
||||
if isosx:
|
||||
libc = ctypes.CDLL(None, use_errno=True)
|
||||
sendfile = ctypes.CFUNCTYPE(
|
||||
ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int64, ctypes.POINTER(ctypes.c_int64), ctypes.c_void_p, ctypes.c_int, use_errno=True)(
|
||||
@ -68,7 +73,19 @@ elif isosx:
|
||||
offset += num_bytes.value
|
||||
return total_sent
|
||||
|
||||
else:
|
||||
def sendfile_to_socket_async(fileobj, offset, size, socket_file):
|
||||
num_bytes = ctypes.c_int64(size)
|
||||
ret = sendfile(fileobj.fileno(), socket_file.fileno(), offset, ctypes.byref(num_bytes), None, 0)
|
||||
if ret != 0:
|
||||
err = ctypes.get_errno()
|
||||
if err in (errno.EBADF, errno.ENOTSUP, errno.ENOTSOCK, errno.EOPNOTSUPP):
|
||||
raise CannotSendfile()
|
||||
if err in (errno.EINTR, errno.EAGAIN):
|
||||
raise SendfileInterrupted()
|
||||
raise IOError((err, os.strerror(err)))
|
||||
return num_bytes.value
|
||||
|
||||
elif islinux:
|
||||
libc = ctypes.CDLL(None, use_errno=True)
|
||||
sendfile = ctypes.CFUNCTYPE(
|
||||
ctypes.c_ssize_t, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_int64), ctypes.c_size_t, use_errno=True)(('sendfile64', libc))
|
||||
@ -97,3 +114,15 @@ else:
|
||||
size -= sent
|
||||
total_sent += sent
|
||||
return total_sent
|
||||
|
||||
def sendfile_to_socket_async(fileobj, offset, size, socket_file):
|
||||
off = ctypes.c_int64(offset)
|
||||
sent = sendfile(socket_file.fileno(), fileobj.fileno(), ctypes.byref(off), size)
|
||||
if sent < 0:
|
||||
err = ctypes.get_errno()
|
||||
if err in (errno.ENOSYS, errno.EINVAL):
|
||||
raise CannotSendfile()
|
||||
if err in (errno.EINTR, errno.EAGAIN):
|
||||
raise SendfileInterrupted()
|
||||
raise IOError((err, os.strerror(err)))
|
||||
return sent
|
||||
|
@ -36,7 +36,7 @@ class TestServer(Thread):
|
||||
Thread.__init__(self, name='ServerMain')
|
||||
from calibre.srv.opts import Options
|
||||
from calibre.srv.loop import ServerLoop
|
||||
from calibre.srv.http import create_http_handler
|
||||
from calibre.srv.http_response import create_http_handler
|
||||
kwargs['shutdown_timeout'] = kwargs.get('shutdown_timeout', 0.1)
|
||||
self.loop = ServerLoop(
|
||||
create_http_handler(handler),
|
||||
@ -68,5 +68,5 @@ class TestServer(Thread):
|
||||
return httplib.HTTPConnection(self.address[0], self.address[1], strict=True, timeout=timeout)
|
||||
|
||||
def change_handler(self, handler):
|
||||
from calibre.srv.http import create_http_handler
|
||||
self.loop.req_resp_handler = create_http_handler(handler)
|
||||
from calibre.srv.http_response import create_http_handler
|
||||
self.loop.handler = create_http_handler(handler)
|
||||
|
@ -6,57 +6,56 @@ from __future__ import (unicode_literals, division, absolute_import,
|
||||
__license__ = 'GPL v3'
|
||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||
|
||||
import textwrap, httplib, hashlib, zlib, string
|
||||
import httplib, hashlib, zlib, string
|
||||
from io import BytesIO
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from calibre import guess_type
|
||||
from calibre.srv.tests.base import BaseTest, TestServer
|
||||
|
||||
def headers(raw):
|
||||
return BytesIO(textwrap.dedent(raw).encode('utf-8'))
|
||||
|
||||
class TestHTTP(BaseTest):
|
||||
|
||||
def test_header_parsing(self): # {{{
|
||||
'Test parsing of HTTP headers'
|
||||
from calibre.srv.http import read_headers
|
||||
from calibre.srv.http_request import HTTPHeaderParser
|
||||
|
||||
def test(name, raw, **kwargs):
|
||||
hdict = read_headers(headers(raw).readline)
|
||||
self.assertSetEqual(set(hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')
|
||||
def test(name, *lines, **kwargs):
|
||||
p = HTTPHeaderParser()
|
||||
p.push(*lines)
|
||||
self.assertTrue(p.finished)
|
||||
self.assertSetEqual(set(p.hdict.items()), {(k.replace('_', '-').title(), v) for k, v in kwargs.iteritems()}, name + ' failed')
|
||||
|
||||
test('Continuation line parsing',
|
||||
'''\
|
||||
a: one\r
|
||||
b: two\r
|
||||
2\r
|
||||
\t3\r
|
||||
c:three\r
|
||||
\r\n''', a='one', b='two 2 3', c='three')
|
||||
'a: one',
|
||||
'b: two',
|
||||
' 2',
|
||||
'\t3',
|
||||
'c:three',
|
||||
'\r\n', a='one', b='two 2 3', c='three')
|
||||
|
||||
test('Non-ascii headers parsing',
|
||||
'''\
|
||||
a:mūs\r
|
||||
\r\n''', a='mūs'.encode('utf-8'))
|
||||
b'a:mūs\r', '\r\n', a='mūs'.encode('utf-8'))
|
||||
|
||||
test('Comma-separated parsing',
|
||||
'''\
|
||||
Accept-Encoding: one\r
|
||||
Accept-Encoding: two\r
|
||||
\r\n''', accept_encoding='one, two')
|
||||
'Accept-Encoding: one',
|
||||
'accept-Encoding: two',
|
||||
'\r\n', accept_encoding='one, two')
|
||||
|
||||
def parse(line):
|
||||
HTTPHeaderParser()(line)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
read_headers(headers('Connection:mūs\r\n').readline)
|
||||
read_headers(headers('Connection\r\n').readline)
|
||||
read_headers(headers('Connection:a\r\n').readline)
|
||||
read_headers(headers('Connection:a\n').readline)
|
||||
read_headers(headers(' Connection:a\n').readline)
|
||||
parse('Connection:mūs\r\n')
|
||||
parse('Connection\r\n')
|
||||
parse('Connection:a\r\n')
|
||||
parse('Connection:a\n')
|
||||
parse(' Connection:a\n')
|
||||
parse(':a\n')
|
||||
# }}}
|
||||
|
||||
def test_accept_encoding(self): # {{{
|
||||
'Test parsing of Accept-Encoding'
|
||||
from calibre.srv.respond import acceptable_encoding
|
||||
from calibre.srv.http_response import acceptable_encoding
|
||||
def test(name, val, ans, allowed={'gzip'}):
|
||||
self.ae(acceptable_encoding(val, allowed), ans, name + ' failed')
|
||||
test('Empty field', '', None)
|
||||
@ -68,7 +67,7 @@ class TestHTTP(BaseTest):
|
||||
|
||||
def test_range_parsing(self): # {{{
|
||||
'Test parsing of Range header'
|
||||
from calibre.srv.respond import get_ranges
|
||||
from calibre.srv.http_response import get_ranges
|
||||
def test(val, *args):
|
||||
pval = get_ranges(val, 100)
|
||||
if len(args) == 1 and args[0] is None:
|
||||
@ -91,11 +90,38 @@ class TestHTTP(BaseTest):
|
||||
'Test basic HTTP protocol conformance'
|
||||
from calibre.srv.errors import HTTP404
|
||||
body = 'Requested resource not found'
|
||||
def handler(conn):
|
||||
def handler(data):
|
||||
raise HTTP404(body)
|
||||
def raw_send(conn, raw):
|
||||
conn.send(raw)
|
||||
conn._HTTPConnection__state = httplib._CS_REQ_SENT
|
||||
return conn.getresponse()
|
||||
|
||||
with TestServer(handler, timeout=0.1, max_header_line_size=100./1024, max_request_body_size=100./(1024*1024)) as server:
|
||||
# Test 404
|
||||
conn = server.connect()
|
||||
r = raw_send(conn, b'hello\n')
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.ae(r.read(), b'HTTP requires CRLF line terminators')
|
||||
|
||||
r = raw_send(conn, b'\r\nGET /index.html HTTP/1.1\r\n\r\n')
|
||||
self.ae(r.status, httplib.NOT_FOUND), self.ae(r.read(), b'Requested resource not found')
|
||||
|
||||
r = raw_send(conn, b'\r\n\r\nGET /index.html HTTP/1.1\r\n\r\n')
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.ae(r.read(), b'Multiple leading empty lines not allowed')
|
||||
|
||||
r = raw_send(conn, b'hello world\r\n')
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.ae(r.read(), b'Malformed Request-Line')
|
||||
|
||||
r = raw_send(conn, b'x' * 200)
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.ae(r.read(), b'')
|
||||
|
||||
r = raw_send(conn, b'XXX /index.html HTTP/1.1\r\n\r\n')
|
||||
self.ae(r.status, httplib.BAD_REQUEST), self.ae(r.read(), b'Unknown HTTP method')
|
||||
|
||||
# Test 404
|
||||
conn.request('HEAD', '/moose')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.NOT_FOUND)
|
||||
@ -104,32 +130,48 @@ class TestHTTP(BaseTest):
|
||||
self.ae(r.getheader('Content-Type'), 'text/plain; charset=UTF-8')
|
||||
self.ae(len(r.getheaders()), 3)
|
||||
self.ae(r.read(), '')
|
||||
conn.request('GET', '/moose')
|
||||
conn.request('GET', '/choose')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.NOT_FOUND)
|
||||
self.ae(r.read(), 'Requested resource not found')
|
||||
self.ae(r.read(), b'Requested resource not found')
|
||||
|
||||
server.change_handler(lambda conn:conn.path[0] + conn.input_reader.read().decode('ascii'))
|
||||
# Test 500
|
||||
orig = server.loop.log.filter_level
|
||||
server.loop.log.filter_level = server.loop.log.ERROR + 10
|
||||
server.change_handler(lambda data:1/0)
|
||||
conn = server.connect()
|
||||
conn.request('GET', '/test/')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.INTERNAL_SERVER_ERROR)
|
||||
server.loop.log.filter_level = orig
|
||||
|
||||
server.change_handler(lambda data:data.path[0] + data.read().decode('ascii'))
|
||||
conn = server.connect()
|
||||
|
||||
# Test simple GET
|
||||
conn.request('GET', '/test/')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.OK)
|
||||
self.ae(r.read(), 'test')
|
||||
self.ae(r.read(), b'test')
|
||||
|
||||
# Test TRACE
|
||||
lines = ['TRACE /xxx HTTP/1.1', 'Test: value', 'Xyz: abc, def', '', '']
|
||||
r = raw_send(conn, ('\r\n'.join(lines)).encode('ascii'))
|
||||
self.ae(r.status, httplib.OK)
|
||||
self.ae(r.read().decode('utf-8'), '\n'.join(lines[:-2]))
|
||||
|
||||
# Test POST with simple body
|
||||
conn.request('POST', '/test', 'body')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.CREATED)
|
||||
self.ae(r.read(), 'testbody')
|
||||
self.ae(r.read(), b'testbody')
|
||||
|
||||
# Test POST with chunked transfer encoding
|
||||
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
|
||||
conn.send(b'4\r\nbody\r\n0\r\n\r\n')
|
||||
conn.send(b'4\r\nbody\r\na\r\n1234567890\r\n0\r\n\r\n')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.CREATED)
|
||||
self.ae(r.read(), 'testbody')
|
||||
self.ae(r.read(), b'testbody1234567890')
|
||||
|
||||
# Test various incorrect input
|
||||
orig_level, server.log.filter_level = server.log.filter_level, server.log.ERROR
|
||||
@ -150,19 +192,26 @@ class TestHTTP(BaseTest):
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.assertIn(b'not a valid chunk size', r.read())
|
||||
|
||||
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
|
||||
conn.send(b'4\r\nbody\r\n200\r\n\r\n')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.REQUEST_ENTITY_TOO_LARGE)
|
||||
conn.request('POST', '/test', body='a'*200)
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.REQUEST_ENTITY_TOO_LARGE)
|
||||
|
||||
conn = server.connect()
|
||||
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
|
||||
conn.send(b'3\r\nbody\r\n0\r\n\r\n')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.assertIn(b'!= CRLF', r.read())
|
||||
self.ae(r.status, httplib.BAD_REQUEST), self.ae(r.read(), b'Chunk does not have trailing CRLF')
|
||||
|
||||
conn = server.connect(timeout=1)
|
||||
conn.request('POST', '/test', headers={'Transfer-Encoding': 'chunked'})
|
||||
conn.send(b'30\r\nbody\r\n0\r\n\r\n')
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.BAD_REQUEST)
|
||||
self.assertIn(b'Timed out waiting for chunk', r.read())
|
||||
self.ae(r.status, httplib.REQUEST_TIMEOUT)
|
||||
self.assertIn(b'', r.read())
|
||||
|
||||
server.log.filter_level = orig_level
|
||||
conn = server.connect()
|
||||
@ -180,18 +229,17 @@ class TestHTTP(BaseTest):
|
||||
|
||||
# Test closing
|
||||
conn.request('GET', '/close', headers={'Connection':'close'})
|
||||
self.ae(server.loop.requests.busy, 1)
|
||||
r = conn.getresponse()
|
||||
self.ae(server.loop.num_active_connections, 1)
|
||||
self.ae(r.status, 200), self.ae(r.read(), 'close')
|
||||
self.ae(server.loop.requests.busy, 0)
|
||||
server.loop.wakeup()
|
||||
self.ae(server.loop.num_active_connections, 0)
|
||||
self.assertIsNone(conn.sock)
|
||||
self.ae(server.loop.requests.idle, 10)
|
||||
|
||||
# }}}
|
||||
|
||||
def test_http_response(self): # {{{
|
||||
'Test HTTP protocol responses'
|
||||
from calibre.srv.respond import parse_multipart_byterange
|
||||
from calibre.srv.http_response import parse_multipart_byterange
|
||||
def handler(conn):
|
||||
return conn.generate_static_output('test', lambda : ''.join(conn.path))
|
||||
with TestServer(handler, timeout=0.1, compress_min_size=0) as server, \
|
||||
@ -216,9 +264,10 @@ class TestHTTP(BaseTest):
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.OK), self.ae(zlib.decompress(r.read(), 16+zlib.MAX_WBITS), b'an_etagged_path')
|
||||
|
||||
for i in '12':
|
||||
# Test getting a filesystem file
|
||||
# Test getting a filesystem file
|
||||
for use_sendfile in (True, False):
|
||||
server.change_handler(lambda conn: f)
|
||||
server.loop.opts.use_sendfile = use_sendfile
|
||||
conn = server.connect()
|
||||
conn.request('GET', '/test')
|
||||
r = conn.getresponse()
|
||||
@ -229,27 +278,27 @@ class TestHTTP(BaseTest):
|
||||
self.ae(int(r.getheader('Content-Length')), len(fdata))
|
||||
self.ae(r.status, httplib.OK), self.ae(r.read(), fdata)
|
||||
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=0-25'})
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=2-25'})
|
||||
r = conn.getresponse()
|
||||
self.ae(type('')(r.getheader('Accept-Ranges')), 'bytes')
|
||||
self.ae(type('')(r.getheader('Content-Range')), 'bytes 0-25/%d' % len(fdata))
|
||||
self.ae(int(r.getheader('Content-Length')), 26)
|
||||
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[0:26])
|
||||
self.ae(type('')(r.getheader('Content-Range')), 'bytes 2-25/%d' % len(fdata))
|
||||
self.ae(int(r.getheader('Content-Length')), 24)
|
||||
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[2:26])
|
||||
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=100000-'})
|
||||
r = conn.getresponse()
|
||||
self.ae(type('')(r.getheader('Content-Range')), 'bytes */%d' % len(fdata))
|
||||
self.ae(r.status, httplib.REQUESTED_RANGE_NOT_SATISFIABLE)
|
||||
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=0-1000000'})
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata)
|
||||
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':etag})
|
||||
r = conn.getresponse()
|
||||
self.ae(int(r.getheader('Content-Length')), 26)
|
||||
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[25:51])
|
||||
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=0-1000000'})
|
||||
r = conn.getresponse()
|
||||
self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata)
|
||||
|
||||
conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':'"nomatch"'})
|
||||
r = conn.getresponse()
|
||||
self.assertFalse(r.getheader('Content-Range'))
|
||||
@ -274,6 +323,4 @@ class TestHTTP(BaseTest):
|
||||
self.ae(data, r.read())
|
||||
|
||||
# Now try it without sendfile
|
||||
server.loop.opts.use_sendfile ^= True
|
||||
conn = server.connect()
|
||||
# }}}
|
||||
|
@ -11,9 +11,13 @@ from contextlib import closing
|
||||
from urlparse import parse_qs
|
||||
import repr as reprlib
|
||||
from email.utils import formatdate
|
||||
from operator import itemgetter
|
||||
|
||||
from calibre.constants import iswindows
|
||||
|
||||
HTTP1 = 'HTTP/1.0'
|
||||
HTTP11 = 'HTTP/1.1'
|
||||
|
||||
def http_date(timeval=None):
|
||||
return type('')(formatdate(timeval=timeval, usegmt=True))
|
||||
|
||||
@ -85,7 +89,8 @@ class MultiDict(dict): # {{{
|
||||
__str__ = __unicode__ = __repr__
|
||||
|
||||
def pretty(self, leading_whitespace=''):
|
||||
return leading_whitespace + ('\n' + leading_whitespace).join('%s: %s' % (k, v) for k, v in self.items())
|
||||
return leading_whitespace + ('\n' + leading_whitespace).join(
|
||||
'%s: %s' % (k, (repr(v) if isinstance(v, bytes) else v)) for k, v in sorted(self.items(), key=itemgetter(0)))
|
||||
# }}}
|
||||
|
||||
def error_codes(*errnames):
|
||||
@ -112,29 +117,15 @@ socket_errors_socket_closed = error_codes( # errors indicating a disconnected c
|
||||
socket_errors_nonblocking = error_codes(
|
||||
'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
|
||||
|
||||
class Corked(object):
|
||||
def start_cork(sock):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
|
||||
if hasattr(socket, 'TCP_CORK'):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1)
|
||||
|
||||
' Context manager to turn on TCP corking. Ensures maximum throughput for large logical packets. '
|
||||
|
||||
def __init__(self, sock):
|
||||
self.sock = sock
|
||||
|
||||
def __enter__(self):
|
||||
nodelay = self.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
|
||||
if nodelay == 1:
|
||||
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
|
||||
self.set_nodelay = True
|
||||
else:
|
||||
self.set_nodelay = False
|
||||
if hasattr(socket, 'TCP_CORK'):
|
||||
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 1)
|
||||
|
||||
def __exit__(self, *args):
|
||||
if self.set_nodelay:
|
||||
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
if hasattr(socket, 'TCP_CORK'):
|
||||
self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)
|
||||
self.sock.send(b'') # Ensure that uncorking occurs
|
||||
def stop_cork(sock):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
if hasattr(socket, 'TCP_CORK'):
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_CORK, 0)
|
||||
|
||||
def create_sock_pair(port=0):
|
||||
'''Create socket pair. Works also on windows by using an ephemeral TCP port.'''
|
||||
|
Loading…
x
Reference in New Issue
Block a user