diff --git a/src/calibre/srv/errors.py b/src/calibre/srv/errors.py index f0ad28cae4..72388edcaa 100644 --- a/src/calibre/srv/errors.py +++ b/src/calibre/srv/errors.py @@ -30,3 +30,9 @@ class IfNoneMatch(Exception): class BadChunkedInput(ValueError): pass + +class RangeNotSatisfiable(ValueError): + + def __init__(self, content_length): + ValueError.__init__(self) + self.content_length = content_length diff --git a/src/calibre/srv/http.py b/src/calibre/srv/http.py index 3e2a5bebcc..a2e50658cb 100644 --- a/src/calibre/srv/http.py +++ b/src/calibre/srv/http.py @@ -16,7 +16,7 @@ from operator import itemgetter from calibre import as_unicode from calibre.constants import __version__ from calibre.srv.errors import ( - MaxSizeExceeded, NonHTTPConnRequest, HTTP404, IfNoneMatch, BadChunkedInput) + MaxSizeExceeded, NonHTTPConnRequest, HTTP404, IfNoneMatch, BadChunkedInput, RangeNotSatisfiable) from calibre.srv.respond import finalize_output, generate_static_output from calibre.srv.utils import MultiDict, http_date, socket_errors_to_ignore @@ -499,14 +499,26 @@ class HTTPPair(object): ] if etag is not None: buf.append('ETag: ' + etag) - for header in ('Expires', 'Cache-Control', 'Vary'): - val = self.outheaders.get(header) - if val: - buf.append(header + ': ' + val) + self.send_buf(buf) + + def send_buf(self, buf, include_cache_headers=True): + if include_cache_headers: + for header in ('Expires', 'Cache-Control', 'Vary'): + val = self.outheaders.get(header) + if val: + buf.append(header + ': ' + val) buf.append('') buf = [(x + '\r\n').encode('ascii') for x in buf] self.flushed_write(b''.join(buf)) + def send_range_not_satisfiable(self, content_length): + buf = [ + '%s %d %s' % (self.response_protocol, httplib.REQUESTED_RANGE_NOT_SATISFIABLE, httplib.responses[httplib.REQUESTED_RANGE_NOT_SATISFIABLE]), + "Date: " + http_date(), + "Content-Range: bytes */%d" % content_length, + ] + self.send_buf(buf) + def flushed_write(self, data): self.conn.socket_file.write(data) self.conn.socket_file.flush() @@ -552,6 +564,9 @@ class HTTPPair(object): else: self.simple_response(httplib.PRECONDITION_FAILED) return + except RangeNotSatisfiable as e: + self.send_range_not_satisfiable(e.content_length) + return with self.conn.corked: self.send_headers() diff --git a/src/calibre/srv/respond.py b/src/calibre/srv/respond.py index ed50e560b1..9c16879d3f 100644 --- a/src/calibre/srv/respond.py +++ b/src/calibre/srv/respond.py @@ -6,16 +6,22 @@ from __future__ import (unicode_literals, division, absolute_import, __license__ = 'GPL v3' __copyright__ = '2015, Kovid Goyal ' -import os, hashlib, shutil, httplib, zlib, struct, time +import os, hashlib, shutil, httplib, zlib, struct, time, uuid from io import DEFAULT_BUFFER_SIZE, BytesIO +from collections import namedtuple +from functools import partial +from future_builtins import map +from itertools import izip_longest -from calibre import force_unicode -from calibre.srv.errors import IfNoneMatch +from calibre import force_unicode, guess_type +from calibre.srv.errors import IfNoneMatch, RangeNotSatisfiable + +Range = namedtuple('Range', 'start stop size') +MULTIPART_SEPARATOR = uuid.uuid4().hex.decode('ascii') def get_ranges(headervalue, content_length): - """Return a list of (start, num_of_bytes) indices from a Range header, or None. - If this function returns an empty list, it indicates no valid range was found. - """ + ''' Return a list of ranges from the Range header. If this function returns + an empty list, it indicates no valid range was found. ''' if not headervalue: return None @@ -40,7 +46,7 @@ def get_ranges(headervalue, content_length): continue if stop < start: continue - result.append((start, stop - start + 1)) + result.append(Range(start, stop, stop - start + 1)) elif stop: # Negative subscript (last N bytes) try: @@ -48,9 +54,9 @@ def get_ranges(headervalue, content_length): except Exception: continue if stop > content_length: - result.append((0, content_length)) + result.append(Range(0, content_length-1, content_length)) else: - result.append((content_length - stop, stop)) + result.append(Range(content_length - stop, content_length - 1, stop)) return result @@ -111,24 +117,91 @@ def write_compressed_file_obj(input_file, dest, compress_level=6): write_chunked_data(dest, data) write_chunked_data(dest, b'') +def get_range_parts(ranges, content_type, content_length): + + def part(r): + ans = ['--%s' % MULTIPART_SEPARATOR, 'Content-Range: bytes %d-%d/%d' % (r.start, r.stop, content_length)] + if content_type: + ans.append('Content-Type: %s' % content_type) + ans.append('') + return ('\r\n'.join(ans)).encode('ascii') + return list(map(part, ranges)) + [('--%s--' % MULTIPART_SEPARATOR).encode('ascii')] + +def parse_multipart_byterange(buf, content_type): + from calibre.srv.http import read_headers + sep = (content_type.rsplit('=', 1)[-1]).encode('utf-8') + ans = [] + + def parse_part(): + line = buf.readline() + if not line: + raise ValueError('Premature end of message') + if not line.startswith(b'--' + sep): + raise ValueError('Malformed start of multipart message') + if line.endswith(b'--'): + return None + headers = read_headers(buf.readline) + cr = headers.get('Content-Range') + if not cr: + raise ValueError('Missing Content-Range header in sub-part') + if not cr.startswith('bytes '): + raise ValueError('Malformed Content-Range header in sub-part, no prefix') + try: + start, stop = map(lambda x: int(x.strip()), cr.partition(' ')[-1].partition('/')[0].partition('-')[::2]) + except Exception: + raise ValueError('Malformed Content-Range header in sub-part, failed to parse byte range') + content_length = stop - start + 1 + ret = buf.read(content_length) + if len(ret) != content_length: + raise ValueError('Malformed sub-part, length of body not equal to length specified in Content-Range') + buf.readline() + return (start, ret) + while True: + data = parse_part() + if data is None: + break + ans.append(data) + return ans class FileSystemOutputFile(object): - def __init__(self, output, outheaders): - self.output_file = output - pos = output.tell() - output.seek(0, os.SEEK_END) - self.content_length = output.tell() - pos + def __init__(self, output, outheaders, size): + self.src_file = output + self.name = output.name + self.content_length = size self.etag = '"%s"' % hashlib.sha1(type('')(os.fstat(output.fileno()).st_mtime) + force_unicode(output.name or '')).hexdigest() - output.seek(pos) self.accept_ranges = True def write(self, dest): - shutil.copyfileobj(self.output_file, dest) - self.output_file = None + self.src_file.seek(0) + shutil.copyfileobj(self.src_file, dest) + self.src_file = None def write_compressed(self, dest): - write_compressed_file_obj(self.output_file, dest) + self.src_file.seek(0) + write_compressed_file_obj(self.src_file, dest) + self.src_file = None + + def write_ranges(self, ranges, dest): + if isinstance(ranges, Range): + r = ranges + self.copy_range(r.start, r.size, dest) + else: + for r, header in ranges: + dest.write(header) + if r is not None: + dest.write(b'\r\n') + self.copy_range(r.start, r.size, dest) + dest.write(b'\r\n') + self.src_file = None + + def copy_range(self, start, size, dest): + self.src_file.seek(start) + while size > 0: + data = self.src_file.read(min(size, DEFAULT_BUFFER_SIZE)) + dest.write(data) + size -= len(data) + del data class DynamicOutput(object): @@ -193,8 +266,19 @@ def parse_if_none_match(val): def finalize_output(output, inheaders, outheaders, status_code, is_http1, method, compress_min_size): ct = outheaders.get('Content-Type', '') compressible = not ct or ct.startswith('text/') or ct.startswith('image/svg') or ct.startswith('application/json') - if isinstance(output, file): - output = FileSystemOutputFile(output, outheaders) + try: + fd = output.fileno() + fsize = os.fstat(fd).st_size + except Exception: + fd = fsize = None + if fsize is not None: + output = FileSystemOutputFile(output, outheaders, fsize) + if 'Content-Type' not in outheaders: + mt = guess_type(output.name)[0] + if mt: + if mt in ('text/plain', 'text/html'): + mt =+ '; charset=UTF-8' + outheaders['Content-Type'] = mt elif isinstance(output, (bytes, type(''))): output = DynamicOutput(output, outheaders) elif isinstance(output, StaticGeneratedOutput): @@ -206,7 +290,12 @@ def finalize_output(output, inheaders, outheaders, status_code, is_http1, method acceptable_encoding(inheaders.get('Accept-Encoding', '')) and not is_http1) accept_ranges = (not compressible and output.accept_ranges is not None and status_code == httplib.OK and not is_http1) - ranges = None + ranges = get_ranges(inheaders.get('Range'), output.content_length) if output.accept_ranges and method in ('GET', 'HEAD') else None + if_range = (inheaders.get('If-Range') or '').strip() + if if_range and if_range != output.etag: + ranges = None + if ranges is not None and not ranges: + raise RangeNotSatisfiable(output.content_length) for header in ('Accept-Ranges', 'Content-Encoding', 'Transfer-Encoding', 'ETag', 'Content-Length'): outheaders.pop('header', all=True) @@ -216,8 +305,6 @@ def finalize_output(output, inheaders, outheaders, status_code, is_http1, method if matched: raise IfNoneMatch(output.etag) - # TODO: Ranges, If-Range - if output.etag and method in ('GET', 'HEAD'): outheaders.set('ETag', output.etag, replace_all=True) if accept_ranges: @@ -230,6 +317,20 @@ def finalize_output(output, inheaders, outheaders, status_code, is_http1, method if compressible or output.content_length is None: outheaders.set('Transfer-Encoding', 'chunked', replace_all=True) - output.commit = output.write_compressed if compressible else output.write + if ranges: + if len(ranges) == 1: + r = ranges[0] + outheaders.set('Content-Length', '%d' % r.size, replace_all=True) + outheaders.set('Content-Range', 'bytes %d-%d/%d' % (r.start, r.stop, output.content_length), replace_all=True) + output.commit = partial(output.write_ranges, r) + else: + range_parts = get_range_parts(ranges, outheaders.get('Content-Type'), output.content_length) + size = sum(map(len, range_parts)) + sum(r.size + 4 for r in ranges) + outheaders.set('Content-Length', '%d' % size, replace_all=True) + outheaders.set('Content-Type', 'multipart/byteranges; boundary=' + MULTIPART_SEPARATOR, replace_all=True) + output.commit = partial(output.write_ranges, izip_longest(ranges, range_parts)) + status_code = httplib.PARTIAL_CONTENT + else: + output.commit = output.write_compressed if compressible else output.write return status_code, output diff --git a/src/calibre/srv/tests/http.py b/src/calibre/srv/tests/http.py index eced708e95..c96672e973 100644 --- a/src/calibre/srv/tests/http.py +++ b/src/calibre/srv/tests/http.py @@ -8,8 +8,9 @@ __copyright__ = '2015, Kovid Goyal ' import textwrap, httplib, hashlib, zlib, string from io import BytesIO +from tempfile import NamedTemporaryFile -from calibre.ptempfile import PersistentTemporaryFile +from calibre import guess_type from calibre.srv.tests.base import BaseTest, TestServer def headers(raw): @@ -73,17 +74,17 @@ class TestHTTP(BaseTest): if len(args) == 1 and args[0] is None: self.assertIsNone(pval, val) else: - self.assertListEqual(pval, list(args), val) + self.assertListEqual([tuple(x) for x in pval], list(args), val) test('crap', None) test('crap=', None) test('crap=1', None) test('crap=1-2', None) test('bytes=a-2') - test('bytes=0-99', (0, 100)) - test('bytes=0-0,-1', (0, 1), (99, 1)) - test('bytes=-5', (95, 5)) - test('bytes=95-', (95, 5)) - test('bytes=-200', (0, 100)) + test('bytes=0-99', (0, 99, 100)) + test('bytes=0-0,-1', (0, 0, 1), (99, 99, 1)) + test('bytes=-5', (95, 99, 5)) + test('bytes=95-', (95, 99, 5)) + test('bytes=-200', (0, 99, 100)) # }}} def test_http_basic(self): # {{{ @@ -190,12 +191,12 @@ class TestHTTP(BaseTest): def test_http_response(self): # {{{ 'Test HTTP protocol responses' + from calibre.srv.respond import parse_multipart_byterange def handler(conn): return conn.generate_static_output('test', lambda : ''.join(conn.path)) - with TestServer(handler, timeout=0.1, compress_min_size=0) as server, PersistentTemporaryFile('test.epub') as f: + with TestServer(handler, timeout=0.1, compress_min_size=0) as server, NamedTemporaryFile(suffix='test.epub') as f: fdata = string.ascii_letters * 100 - f.write(fdata) - f.close() + f.write(fdata), f.seek(0) # Test ETag conn = server.connect() @@ -214,4 +215,47 @@ class TestHTTP(BaseTest): r = conn.getresponse() self.ae(r.status, httplib.OK), self.ae(zlib.decompress(r.read(), 16+zlib.MAX_WBITS), b'an_etagged_path') + # Test getting a filesystem file + server.change_handler(lambda conn: f) + conn = server.connect() + conn.request('GET', '/test') + r = conn.getresponse() + etag = type('')(r.getheader('ETag')) + self.assertTrue(etag) + self.ae(r.getheader('Content-Type'), guess_type(f.name)[0]) + self.ae(type('')(r.getheader('Accept-Ranges')), 'bytes') + self.ae(int(r.getheader('Content-Length')), len(fdata)) + self.ae(r.status, httplib.OK), self.ae(r.read(), fdata) + + conn.request('GET', '/test', headers={'Range':'bytes=0-25'}) + r = conn.getresponse() + self.ae(type('')(r.getheader('Accept-Ranges')), 'bytes') + self.ae(type('')(r.getheader('Content-Range')), 'bytes 0-25/%d' % len(fdata)) + self.ae(int(r.getheader('Content-Length')), 26) + self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[0:26]) + + conn.request('GET', '/test', headers={'Range':'bytes=100000-'}) + r = conn.getresponse() + self.ae(type('')(r.getheader('Content-Range')), 'bytes */%d' % len(fdata)) + self.ae(r.status, httplib.REQUESTED_RANGE_NOT_SATISFIABLE), self.ae(r.read(), b'') + + conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':etag}) + r = conn.getresponse() + self.ae(int(r.getheader('Content-Length')), 26) + self.ae(r.status, httplib.PARTIAL_CONTENT), self.ae(r.read(), fdata[25:51]) + + conn.request('GET', '/test', headers={'Range':'bytes=25-50', 'If-Range':'"nomatch"'}) + r = conn.getresponse() + self.assertFalse(r.getheader('Content-Range')) + self.ae(int(r.getheader('Content-Length')), len(fdata)) + self.ae(r.status, httplib.OK), self.ae(r.read(), fdata) + + conn.request('GET', '/test', headers={'Range':'bytes=0-25,26-50'}) + r = conn.getresponse() + clen = int(r.getheader('Content-Length')) + data = r.read() + self.ae(clen, len(data)) + buf = BytesIO(data) + self.ae(parse_multipart_byterange(buf, r.getheader('Content-Type')), [(0, fdata[:26]), (26, fdata[26:51])]) + # }}}