diff --git a/src/calibre/srv/http_request.py b/src/calibre/srv/http_request.py index c020e5bdc2..c2d1718e12 100644 --- a/src/calibre/srv/http_request.py +++ b/src/calibre/srv/http_request.py @@ -160,7 +160,7 @@ class HTTPRequest(Connection): def read(self, buf, endpos): size = endpos - buf.tell() if size > 0: - data = self.recv(min(size, DEFAULT_BUFFER_SIZE)) + data = self.recv(size) if data: buf.write(data) return len(data) >= size @@ -170,19 +170,21 @@ class HTTPRequest(Connection): return True def readline(self, buf): - if buf.tell() >= self.max_header_line_size - 1: + line = self.read_buffer.readline() + buf.write(line) + if buf.tell() > self.max_header_line_size: self.simple_response(self.header_line_too_long_error_code) return - data = self.recv(1) - if data: - buf.write(data) - if b'\n' == data: - line = buf.getvalue() - buf.seek(0), buf.truncate() - if line.endswith(b'\r\n'): - return line - else: - self.simple_response(httplib.BAD_REQUEST, 'HTTP requires CRLF line terminators') + if line.endswith(b'\n'): + line = buf.getvalue() + buf.seek(0), buf.truncate() + if not line.endswith(b'\r\n'): + self.simple_response(httplib.BAD_REQUEST, 'HTTP requires CRLF line terminators') + return + return line + if not line: + # read buffer is empty, fill it + self.fill_read_buffer() def connection_ready(self): 'Become ready to read an HTTP request' diff --git a/src/calibre/srv/loop.py b/src/calibre/srv/loop.py index bf27245481..9a55ef43d2 100644 --- a/src/calibre/srv/loop.py +++ b/src/calibre/srv/loop.py @@ -22,6 +22,89 @@ from calibre.utils.monotonic import monotonic READ, WRITE, RDWR = 'READ', 'WRITE', 'RDWR' +class ReadBuffer(object): + + __slots__ = ('ba', 'buf', 'read_pos', 'write_pos', 'full_state') + + def __init__(self, size=4096): + self.ba = bytearray(size) + self.buf = memoryview(self.ba) + self.read_pos = 0 + self.write_pos = 0 + self.full_state = WRITE + + @property + def has_data(self): + return self.read_pos != self.write_pos or self.full_state is READ + + @property + def has_space(self): + return self.read_pos != self.write_pos or self.full_state is WRITE + + def read(self, size): + # Read from this buffer, retuning the read bytes as a bytestring + if self.read_pos == self.write_pos and self.full_state is WRITE: + return b'' + if self.read_pos < self.write_pos: + sz = min(self.write_pos - self.read_pos, size) + npos = self.read_pos + sz + ans = self.buf[self.read_pos:npos].tobytes() + self.read_pos = npos + if self.read_pos == self.write_pos: + self.full_state = WRITE + else: + sz = min(size, len(self.buf) - self.read_pos) + ans = self.buf[self.read_pos:self.read_pos + sz].tobytes() + self.read_pos = (self.read_pos + sz) % len(self.buf) + if self.read_pos == self.write_pos: + self.full_state = WRITE + if size > sz and self.read_pos < self.write_pos: + ans += self.read(size - len(ans.buf)) + return ans + + def recv_from(self, socket): + # Write into this buffer from socket, return number of bytes written + if self.read_pos == self.write_pos and self.full_state is READ: + return 0 + if self.write_pos < self.read_pos: + num = socket.recv_into(self.buf[self.write_pos:self.read_pos]) + self.write_pos += num + else: + num = socket.recv_into(self.buf[self.write_pos:]) + self.write_pos = (self.write_pos + num) % len(self.buf) + if self.write_pos == self.read_pos: + self.full_state = READ + return num + + def readline(self): + # Return whatever is in the buffer upto (and including) the first \n + if self.read_pos == self.write_pos and self.full_state is WRITE: + return b'' + if self.read_pos < self.write_pos: + pos = self.ba.find(b'\n', self.read_pos, self.write_pos) + if pos < 0: + pos = self.write_pos - 1 + ans = self.buf[self.read_pos:pos + 1].tobytes() + self.read_pos = (pos + 1) % len(self.buf) + if self.read_pos == self.write_pos: + self.full_state = WRITE + else: + pos = self.ba.find(b'\n', self.read_pos) + if pos < 0: + pos = self.ba.find(b'\n', 0, self.write_pos) + if pos < 0: + pos = self.write_pos - 1 + ans = self.buf[self.read_pos:].tobytes() + self.buf[:pos+1].tobytes() + self.read_pos = (pos + 1) % len(self.buf) + if self.read_pos == self.write_pos: + self.full_state = WRITE + else: + ans = self.buf[self.read_pos:pos + 1].tobytes() + self.read_pos = (pos + 1) % len(self.buf) + if self.read_pos == self.write_pos: + self.full_state = WRITE + return ans + class Connection(object): def __init__(self, socket, opts, ssl_context, tdir): @@ -31,6 +114,7 @@ class Connection(object): self.ssl_context = ssl_context self.wait_for = READ self.response_started = False + self.read_buffer = ReadBuffer() if self.ssl_context is not None: self.ready = False self.socket = self.ssl_context.wrap_socket(socket, server_side=True, do_handshake_on_connect=False) @@ -86,9 +170,15 @@ class Connection(object): return 0 raise - def recv(self, buffer_size): + def recv(self, amt): + # If there is data in the read buffer we have to return only that, + # since we dont know if the socket has signalled it is ready for + # reading + if self.read_buffer.has_data: + return self.read_buffer.read(amt) + # read buffer is empty, so read directly from socket try: - data = self.socket.recv(buffer_size) + data = self.socket.recv(amt) self.last_activity = monotonic() if not data: # a closed connection is indicated by signaling @@ -104,6 +194,22 @@ class Connection(object): return b'' raise + def fill_read_buffer(self): + try: + num = self.read_buffer.recv_from(self.socket) + self.last_activity = monotonic() + if not num: + # a closed connection is indicated by signaling + # a read condition, and having recv() return 0. + self.ready = False + except socket.error as e: + if e.errno in socket_errors_nonblocking or e.errno in socket_errors_eintr: + return + if e.errno in socket_errors_socket_closed: + self.ready = False + return + raise + def close(self): self.ready = False self.handle_event = None # prevent reference cycles @@ -262,19 +368,30 @@ class ServerLoop(object): self.log.debug('Closing connection because of extended inactivity') self.close(s, conn) - read_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is READ or c.wait_for is RDWR] - write_needed = [c.socket for c in self.connection_map.itervalues() if c.wait_for is WRITE or c.wait_for is RDWR] - try: - readable, writable, _ = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout) - except select.error as e: - if e.errno in socket_errors_eintr: + read_needed, write_needed, readable = [], [], [] + for c in self.connection_map.itervalues(): + if c.wait_for is READ: + (readable if c.read_buffer.has_data else read_needed).append(c.socket) + elif c.wait_for is WRITE: + write_needed.append(c.socket) + else: + write_needed.append(c) + (readable if c.read_buffer.has_data else read_needed).append(c.socket) + + if readable: + writable = [] + else: + try: + readable, writable, _ = select.select([self.socket] + read_needed, write_needed, [], self.opts.timeout) + except select.error as e: + if e.errno in socket_errors_eintr: + return + for s, conn in tuple(self.connection_map.iteritems()): + try: + select.select([s], [], [], 0) + except select.error: + self.close(s, conn) # Bad socket, discard return - for s, conn in tuple(self.connection_map.iteritems()): - try: - select.select([s], [], [], 0) - except select.error: - self.close(s, conn) # Bad socket, discard - return if not self.ready: return