diff --git a/src/cherrypy/__init__.py b/src/cherrypy/__init__.py index 317e465062..274f00694b 100644 --- a/src/cherrypy/__init__.py +++ b/src/cherrypy/__init__.py @@ -57,7 +57,7 @@ These API's are described in the CherryPy specification: http://www.cherrypy.org/wiki/CherryPySpec """ -__version__ = "3.1.0" +__version__ = "3.1.1" from urlparse import urljoin as _urljoin @@ -337,6 +337,10 @@ class _ThreadLocalProxy(object): def __len__(self): child = getattr(serving, self.__attrname__) return len(child) + + def __nonzero__(self): + child = getattr(serving, self.__attrname__) + return bool(child) # Create request and response object (the same objects will be used diff --git a/src/cherrypy/_cpconfig.py b/src/cherrypy/_cpconfig.py index 815bb1936e..adc9911b17 100644 --- a/src/cherrypy/_cpconfig.py +++ b/src/cherrypy/_cpconfig.py @@ -303,13 +303,26 @@ def _engine_namespace_handler(k, v): elif k == 'autoreload_match': engine.autoreload.match = v elif k == 'reload_files': - engine.autoreload.files = v + engine.autoreload.files = set(v) elif k == 'deadlock_poll_freq': engine.timeout_monitor.frequency = v elif k == 'SIGHUP': engine.listeners['SIGHUP'] = set([v]) elif k == 'SIGTERM': engine.listeners['SIGTERM'] = set([v]) + elif "." 
in k: + plugin, attrname = k.split(".", 1) + plugin = getattr(engine, plugin) + if attrname == 'on': + if v and callable(getattr(plugin, 'subscribe', None)): + plugin.subscribe() + return + elif (not v) and callable(getattr(plugin, 'unsubscribe', None)): + plugin.unsubscribe() + return + setattr(plugin, attrname, v) + else: + setattr(engine, k, v) Config.namespaces["engine"] = _engine_namespace_handler diff --git a/src/cherrypy/_cpdispatch.py b/src/cherrypy/_cpdispatch.py index 4e7104bf72..c29fb05bbe 100644 --- a/src/cherrypy/_cpdispatch.py +++ b/src/cherrypy/_cpdispatch.py @@ -21,7 +21,127 @@ class PageHandler(object): self.kwargs = kwargs def __call__(self): - return self.callable(*self.args, **self.kwargs) + try: + return self.callable(*self.args, **self.kwargs) + except TypeError, x: + test_callable_spec(self.callable, self.args, self.kwargs) + raise + +def test_callable_spec(callable, callable_args, callable_kwargs): + """ + Inspect callable and test to see if the given args are suitable for it. + + When an error occurs during the handler's invoking stage there are 2 + erroneous cases: + 1. Too many parameters passed to a function which doesn't define + one of *args or **kwargs. + 2. Too little parameters are passed to the function. + + There are 3 sources of parameters to a cherrypy handler. + 1. query string parameters are passed as keyword parameters to the handler. + 2. body parameters are also passed as keyword parameters. + 3. when partial matching occurs, the final path atoms are passed as + positional args. + Both the query string and path atoms are part of the URI. If they are + incorrect, then a 404 Not Found should be raised. Conversely the body + parameters are part of the request; if they are invalid a 400 Bad Request. 
+    """
+    (args, varargs, varkw, defaults) = inspect.getargspec(callable)
+
+    if args and args[0] == 'self':
+        args = args[1:]
+
+    arg_usage = dict([(arg, 0,) for arg in args])
+    vararg_usage = 0
+    varkw_usage = 0
+    extra_kwargs = set()
+
+    for i, value in enumerate(callable_args):
+        try:
+            arg_usage[args[i]] += 1
+        except IndexError:
+            vararg_usage += 1
+
+    for key in callable_kwargs.keys():
+        try:
+            arg_usage[key] += 1
+        except KeyError:
+            varkw_usage += 1
+            extra_kwargs.add(key)
+
+    for arg in args[len(args) - len(defaults or ()):]:
+        # Defaults belong to the *trailing* args; count one only if still unused.
+        if arg_usage[arg] == 0:
+            arg_usage[arg] += 1
+
+    missing_args = []
+    multiple_args = []
+    for key, usage in arg_usage.iteritems():
+        if usage == 0:
+            missing_args.append(key)
+        elif usage > 1:
+            multiple_args.append(key)
+
+    if missing_args:
+        # In the case where the method allows body arguments
+        # there are 3 potential errors:
+        # 1. not enough query string parameters -> 404
+        # 2. not enough body parameters -> 400
+        # 3. not enough path parts (partial matches) -> 404
+        #
+        # We can't actually tell which case it is,
+        # so I'm raising a 404 because that covers 2/3 of the
+        # possibilities
+        #
+        # In the case where the method does not allow body
+        # arguments it's definitely a 404.
+ raise cherrypy.HTTPError(404, + message="Missing parameters: %s" % ",".join(missing_args)) + + # the extra positional arguments come from the path - 404 Not Found + if not varargs and vararg_usage > 0: + raise cherrypy.HTTPError(404) + + body_params = cherrypy.request.body_params or {} + body_params = set(body_params.keys()) + qs_params = set(callable_kwargs.keys()) - body_params + + if multiple_args: + + if qs_params.intersection(set(multiple_args)): + # If any of the multiple parameters came from the query string then + # it's a 404 Not Found + error = 404 + else: + # Otherwise it's a 400 Bad Request + error = 400 + + raise cherrypy.HTTPError(error, + message="Multiple values for parameters: "\ + "%s" % ",".join(multiple_args)) + + if not varkw and varkw_usage > 0: + + # If there were extra query string parameters, it's a 404 Not Found + extra_qs_params = set(qs_params).intersection(extra_kwargs) + if extra_qs_params: + raise cherrypy.HTTPError(404, + message="Unexpected query string "\ + "parameters: %s" % ", ".join(extra_qs_params)) + + # If there were any extra body parameters, it's a 400 Not Found + extra_body_params = set(body_params).intersection(extra_kwargs) + if extra_body_params: + raise cherrypy.HTTPError(400, + message="Unexpected body parameters: "\ + "%s" % ", ".join(extra_body_params)) + + +try: + import inspect +except ImportError: + test_callable_spec = lambda callable, args, kwargs: None + class LateParamPageHandler(PageHandler): diff --git a/src/cherrypy/_cplogging.py b/src/cherrypy/_cplogging.py index 12c93d3cd8..8556e108f9 100644 --- a/src/cherrypy/_cplogging.py +++ b/src/cherrypy/_cplogging.py @@ -126,7 +126,6 @@ class LogManager(object): if stream is None: stream=sys.stderr h = logging.StreamHandler(stream) - h.setLevel(logging.DEBUG) h.setFormatter(logfmt) h._cpbuiltin = "screen" log.addHandler(h) @@ -149,7 +148,6 @@ class LogManager(object): def _add_builtin_file_handler(self, log, fname): h = logging.FileHandler(fname) - 
h.setLevel(logging.DEBUG) h.setFormatter(logfmt) h._cpbuiltin = "file" log.addHandler(h) @@ -197,7 +195,6 @@ class LogManager(object): if enable: if not h: h = WSGIErrorHandler() - h.setLevel(logging.DEBUG) h.setFormatter(logfmt) h._cpbuiltin = "wsgi" log.addHandler(h) diff --git a/src/cherrypy/_cprequest.py b/src/cherrypy/_cprequest.py index 9ec310c972..3b245519cb 100644 --- a/src/cherrypy/_cprequest.py +++ b/src/cherrypy/_cprequest.py @@ -8,7 +8,7 @@ import types import cherrypy from cherrypy import _cpcgifs, _cpconfig from cherrypy._cperror import format_exc, bare_error -from cherrypy.lib import http +from cherrypy.lib import http, file_generator class Hook(object): @@ -747,15 +747,6 @@ class Request(object): cherrypy.response.finalize() -def file_generator(input, chunkSize=65536): - """Yield the given input (a file object) in chunks (default 64k). (Core)""" - chunk = input.read(chunkSize) - while chunk: - yield chunk - chunk = input.read(chunkSize) - input.close() - - class Body(object): """The body of the HTTP response (the response entity).""" diff --git a/src/cherrypy/_cpserver.py b/src/cherrypy/_cpserver.py index 53259cb086..0888295bec 100644 --- a/src/cherrypy/_cpserver.py +++ b/src/cherrypy/_cpserver.py @@ -49,6 +49,7 @@ class Server(ServerAdapter): protocol_version = 'HTTP/1.1' reverse_dns = False thread_pool = 10 + thread_pool_max = -1 max_request_header_size = 500 * 1024 max_request_body_size = 100 * 1024 * 1024 instance = None diff --git a/src/cherrypy/_cptools.py b/src/cherrypy/_cptools.py index 930ddab277..80ff583d22 100644 --- a/src/cherrypy/_cptools.py +++ b/src/cherrypy/_cptools.py @@ -466,7 +466,7 @@ _d.log_tracebacks = Tool('before_error_response', cptools.log_traceback) _d.log_headers = Tool('before_error_response', cptools.log_request_headers) _d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100) _d.err_redirect = ErrorTool(cptools.redirect) -_d.etags = Tool('before_finalize', cptools.validate_etags) +_d.etags = 
Tool('before_finalize', cptools.validate_etags, priority=75) _d.decode = Tool('before_handler', encoding.decode) # the order of encoding, gzip, caching is important _d.encode = Tool('before_finalize', encoding.encode, priority=70) diff --git a/src/cherrypy/_cptree.py b/src/cherrypy/_cptree.py index 3dcd4a47e5..36d00865a2 100644 --- a/src/cherrypy/_cptree.py +++ b/src/cherrypy/_cptree.py @@ -153,6 +153,8 @@ class Tree(object): root: an instance of a "controller class" (a collection of page handler methods) which represents the root of the application. + This may also be an Application instance, or None if using + a dispatcher other than the default. script_name: a string containing the "mount point" of the application. This should start with a slash, and be the path portion of the URL at which to mount the given root. For example, if root.index() @@ -168,11 +170,15 @@ class Tree(object): if isinstance(root, Application): app = root + if script_name != "" and script_name != app.script_name: + raise ValueError, "Cannot specify a different script name and pass an Application instance to cherrypy.mount" + script_name = app.script_name else: app = Application(root, script_name) # If mounted at "", add favicon.ico - if script_name == "" and root and not hasattr(root, "favicon_ico"): + if (script_name == "" and root is not None + and not hasattr(root, "favicon_ico")): favicon = os.path.join(os.getcwd(), os.path.dirname(__file__), "favicon.ico") root.favicon_ico = tools.staticfile.handler(favicon) diff --git a/src/cherrypy/_cpwsgi_server.py b/src/cherrypy/_cpwsgi_server.py index 5953cf2ab0..ac8bfaa2ec 100644 --- a/src/cherrypy/_cpwsgi_server.py +++ b/src/cherrypy/_cpwsgi_server.py @@ -43,6 +43,7 @@ class CPWSGIServer(wsgiserver.CherryPyWSGIServer): s.__init__(self, bind_addr, cherrypy.tree, server.thread_pool, server.socket_host, + max = server.thread_pool_max, request_queue_size = server.socket_queue_size, timeout = server.socket_timeout, shutdown_timeout = 
server.shutdown_timeout, diff --git a/src/cherrypy/cherryd b/src/cherrypy/cherryd index 3d5cbdefce..ef1bd7a3d4 100644 --- a/src/cherrypy/cherryd +++ b/src/cherrypy/cherryd @@ -8,9 +8,10 @@ from cherrypy.process import plugins, servers def start(configfiles=None, daemonize=False, environment=None, - fastcgi=False, pidfile=None, imports=None): + fastcgi=False, scgi=False, pidfile=None, imports=None): """Subscribe all engine plugins and start the engine.""" - for i in imports: + sys.path = [''] + sys.path + for i in imports or []: exec "import %s" % i for c in configfiles or []: @@ -35,16 +36,27 @@ def start(configfiles=None, daemonize=False, environment=None, if hasattr(engine, "console_control_handler"): engine.console_control_handler.subscribe() - if fastcgi: - # turn off autoreload when using fastcgi - cherrypy.config.update({'autoreload.on': False}) - + if fastcgi and scgi: + # fastcgi and scgi aren't allowed together. + cherrypy.log.error("fastcgi and scgi aren't allowed together.", 'ENGINE') + sys.exit(1) + elif fastcgi or scgi: + # Turn off autoreload when using fastcgi or scgi. + cherrypy.config.update({'engine.autoreload_on': False}) + # Turn off the default HTTP server (which is subscribed by default). 
cherrypy.server.unsubscribe() - fastcgi_port = cherrypy.config.get('server.socket_port', 4000) - fastcgi_bindaddr = cherrypy.config.get('server.socket_host', '0.0.0.0') - bindAddress = (fastcgi_bindaddr, fastcgi_port) - f = servers.FlupFCGIServer(application=cherrypy.tree, bindAddress=bindAddress) + sock_file = cherrypy.config.get('server.socket_file', None) + if sock_file: + bindAddress = sock_file + else: + flup_port = cherrypy.config.get('server.socket_port', 4000) + flup_bindaddr = cherrypy.config.get('server.socket_host', '0.0.0.0') + bindAddress = (flup_bindaddr, flup_port) + if fastcgi: + f = servers.FlupFCGIServer(application=cherrypy.tree, bindAddress=bindAddress) + else: + f = servers.FlupSCGIServer(application=cherrypy.tree, bindAddress=bindAddress) s = servers.ServerAdapter(engine, httpserver=f, bind_addr=bindAddress) s.subscribe() @@ -70,6 +82,8 @@ if __name__ == '__main__': help="apply the given config environment") p.add_option('-f', action="store_true", dest='fastcgi', help="start a fastcgi server instead of the default HTTP server") + p.add_option('-s', action="store_true", dest='scgi', + help="start a scgi server instead of the default HTTP server") p.add_option('-i', '--import', action="append", dest='imports', help="specify modules to import") p.add_option('-p', '--pidfile', dest='pidfile', default=None, @@ -77,6 +91,6 @@ if __name__ == '__main__': options, args = p.parse_args() start(options.config, options.daemonize, - options.environment, options.fastcgi, options.pidfile, + options.environment, options.fastcgi, options.scgi, options.pidfile, options.imports) diff --git a/src/cherrypy/lib/__init__.py b/src/cherrypy/lib/__init__.py index 4e225cb12e..47be2eddd0 100644 --- a/src/cherrypy/lib/__init__.py +++ b/src/cherrypy/lib/__init__.py @@ -133,3 +133,26 @@ def unrepr(s): return _Builder().build(obj) + +def file_generator(input, chunkSize=65536): + """Yield the given input (a file object) in chunks (default 64k). 
(Core)""" + chunk = input.read(chunkSize) + while chunk: + yield chunk + chunk = input.read(chunkSize) + input.close() + + +def file_generator_limited(fileobj, count, chunk_size=65536): + """Yield the given file object in chunks, stopping after `count` + bytes has been emitted. Default chunk size is 64kB. (Core) + """ + remaining = count + while remaining > 0: + chunk = fileobj.read(min(chunk_size, remaining)) + chunklen = len(chunk) + if chunklen == 0: + return + remaining -= chunklen + yield chunk + diff --git a/src/cherrypy/lib/cptools.py b/src/cherrypy/lib/cptools.py index 966c328945..eefd1ae73a 100644 --- a/src/cherrypy/lib/cptools.py +++ b/src/cherrypy/lib/cptools.py @@ -212,7 +212,7 @@ class SessionAuth(object): def on_check(self, username): pass - def login_screen(self, from_page='..', username='', error_msg=''): + def login_screen(self, from_page='..', username='', error_msg='', **kwargs): return """ Message: %(error_msg)s
@@ -224,7 +224,7 @@ Message: %(error_msg)s """ % {'from_page': from_page, 'username': username, 'error_msg': error_msg} - def do_login(self, username, password, from_page='..'): + def do_login(self, username, password, from_page='..', **kwargs): """Login. May raise redirect, or return True if request handled.""" error_msg = self.check_username_and_password(username, password) if error_msg: @@ -239,7 +239,7 @@ Message: %(error_msg)s self.on_login(username) raise cherrypy.HTTPRedirect(from_page or "/") - def do_logout(self, from_page='..'): + def do_logout(self, from_page='..', **kwargs): """Logout. May raise redirect, or return True if request handled.""" sess = cherrypy.session username = sess.get(self.session_key) diff --git a/src/cherrypy/lib/http.py b/src/cherrypy/lib/http.py index 5449d55b7f..82dfa5bf80 100644 --- a/src/cherrypy/lib/http.py +++ b/src/cherrypy/lib/http.py @@ -251,7 +251,12 @@ def valid_status(status): image_map_pattern = re.compile(r"[0-9]+,[0-9]+") def parse_query_string(query_string, keep_blank_values=True): - """Build a params dictionary from a query_string.""" + """Build a params dictionary from a query_string. + + Duplicate key/value pairs in the provided query_string will be + returned as {'key': [val1, val2, ...]}. Single key/values will + be returned as strings: {'key': 'value'}. + """ if image_map_pattern.match(query_string): # Server-side image map. Map the coords to 'x' and 'y' # (like CGI::Request does). 
diff --git a/src/cherrypy/lib/httpauth.py b/src/cherrypy/lib/httpauth.py index 524db9a6b5..bc658244cd 100644 --- a/src/cherrypy/lib/httpauth.py +++ b/src/cherrypy/lib/httpauth.py @@ -275,7 +275,7 @@ def _computeDigestResponse(auth_map, password, method = "GET", A1 = None,**kwarg else: H_A1 = H(_A1(params, password)) - if qop == "auth" or aop == "auth-int": + if qop in ("auth", "auth-int"): # If the "qop" value is "auth" or "auth-int": # request-digest = <"> < KD ( H(A1), unq(nonce-value) # ":" nc-value @@ -290,7 +290,6 @@ def _computeDigestResponse(auth_map, password, method = "GET", A1 = None,**kwarg params["qop"], H_A2, ) - elif qop is None: # If the "qop" directive is not present (this construction is # for compatibility with RFC 2069): diff --git a/src/cherrypy/lib/profiler.py b/src/cherrypy/lib/profiler.py index 9d5481dd7e..704fec47a5 100644 --- a/src/cherrypy/lib/profiler.py +++ b/src/cherrypy/lib/profiler.py @@ -160,7 +160,15 @@ class ProfileAggregator(Profiler): class make_app: def __init__(self, nextapp, path=None, aggregate=False): - """Make a WSGI middleware app which wraps 'nextapp' with profiling.""" + """Make a WSGI middleware app which wraps 'nextapp' with profiling. + + nextapp: the WSGI application to wrap, usually an instance of + cherrypy.Application. + path: where to dump the profiling output. + aggregate: if True, profile data for all HTTP requests will go in + a single file. If False (the default), each HTTP request will + dump its profile data into a separate file. 
+ """ self.nextapp = nextapp self.aggregate = aggregate if aggregate: diff --git a/src/cherrypy/lib/safemime.py b/src/cherrypy/lib/safemime.py index cf41dbf27b..0d13ae9a91 100644 --- a/src/cherrypy/lib/safemime.py +++ b/src/cherrypy/lib/safemime.py @@ -109,7 +109,7 @@ class MultipartWrapper(object): def safe_multipart(flash_only=False): """Wrap request.rfile in a reader that won't crash on no trailing CRLF.""" h = cherrypy.request.headers - if not h.get('Content-Type').startswith('multipart/'): + if not h.get('Content-Type','').startswith('multipart/'): return if flash_only and not 'Shockwave Flash' in h.get('User-Agent', ''): return diff --git a/src/cherrypy/lib/sessions-r2062.py b/src/cherrypy/lib/sessions-r2062.py new file mode 100644 index 0000000000..881002e6fd --- /dev/null +++ b/src/cherrypy/lib/sessions-r2062.py @@ -0,0 +1,698 @@ +"""Session implementation for CherryPy. + +We use cherrypy.request to store some convenient variables as +well as data about the session for the current request. Instead of +polluting cherrypy.request we use a Session object bound to +cherrypy.session to store these variables. +""" + +import datetime +import os +try: + import cPickle as pickle +except ImportError: + import pickle +import random +import sha +import time +import threading +import types +from warnings import warn + +import cherrypy +from cherrypy.lib import http + + +missing = object() + +class Session(object): + """A CherryPy dict-like Session object (one per request).""" + + __metaclass__ = cherrypy._AttributeDocstrings + + _id = None + id_observers = None + id_observers__doc = "A list of callbacks to which to pass new id's." + + id__doc = "The current session ID." + def _get_id(self): + return self._id + def _set_id(self, value): + self._id = value + for o in self.id_observers: + o(value) + id = property(_get_id, _set_id, doc=id__doc) + + timeout = 60 + timeout__doc = "Number of minutes after which to delete session data." 
+ + locked = False + locked__doc = """ + If True, this session instance has exclusive read/write access + to session data.""" + + loaded = False + loaded__doc = """ + If True, data has been retrieved from storage. This should happen + automatically on the first attempt to access session data.""" + + clean_thread = None + clean_thread__doc = "Class-level Monitor which calls self.clean_up." + + clean_freq = 5 + clean_freq__doc = "The poll rate for expired session cleanup in minutes." + + def __init__(self, id=None, **kwargs): + self.id_observers = [] + self._data = {} + + for k, v in kwargs.iteritems(): + setattr(self, k, v) + + if id is None: + self.regenerate() + else: + self.id = id + if not self._exists(): + # Expired or malicious session. Make a new one. + # See http://www.cherrypy.org/ticket/709. + self.id = None + self.regenerate() + + def regenerate(self): + """Replace the current session (with a new id).""" + if self.id is not None: + self.delete() + + old_session_was_locked = self.locked + if old_session_was_locked: + self.release_lock() + + self.id = None + while self.id is None: + self.id = self.generate_id() + # Assert that the generated id is not already stored. + if self._exists(): + self.id = None + + if old_session_was_locked: + self.acquire_lock() + + def clean_up(self): + """Clean up expired sessions.""" + pass + + try: + os.urandom(20) + except (AttributeError, NotImplementedError): + # os.urandom not available until Python 2.4. Fall back to random.random. 
+ def generate_id(self): + """Return a new session id.""" + return sha.new('%s' % random.random()).hexdigest() + else: + def generate_id(self): + """Return a new session id.""" + return os.urandom(20).encode('hex') + + def save(self): + """Save session data.""" + try: + # If session data has never been loaded then it's never been + # accessed: no need to delete it + if self.loaded: + t = datetime.timedelta(seconds = self.timeout * 60) + expiration_time = datetime.datetime.now() + t + self._save(expiration_time) + + finally: + if self.locked: + # Always release the lock if the user didn't release it + self.release_lock() + + def load(self): + """Copy stored session data into this session instance.""" + data = self._load() + # data is either None or a tuple (session_data, expiration_time) + if data is None or data[1] < datetime.datetime.now(): + # Expired session: flush session data + self._data = {} + else: + self._data = data[0] + self.loaded = True + + # Stick the clean_thread in the class, not the instance. + # The instances are created and destroyed per-request. + cls = self.__class__ + if self.clean_freq and not cls.clean_thread: + # clean_up is in instancemethod and not a classmethod, + # so that tool config can be accessed inside the method. + t = cherrypy.process.plugins.Monitor( + cherrypy.engine, self.clean_up, self.clean_freq * 60) + t.subscribe() + cls.clean_thread = t + t.start() + + def delete(self): + """Delete stored session data.""" + self._delete() + + def __getitem__(self, key): + if not self.loaded: self.load() + return self._data[key] + + def __setitem__(self, key, value): + if not self.loaded: self.load() + self._data[key] = value + + def __delitem__(self, key): + if not self.loaded: self.load() + del self._data[key] + + def pop(self, key, default=missing): + """Remove the specified key and return the corresponding value. + If key is not found, default is returned if given, + otherwise KeyError is raised. 
+ """ + if not self.loaded: self.load() + if default is missing: + return self._data.pop(key) + else: + return self._data.pop(key, default) + + def __contains__(self, key): + if not self.loaded: self.load() + return key in self._data + + def has_key(self, key): + """D.has_key(k) -> True if D has a key k, else False.""" + if not self.loaded: self.load() + return self._data.has_key(key) + + def get(self, key, default=None): + """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None.""" + if not self.loaded: self.load() + return self._data.get(key, default) + + def update(self, d): + """D.update(E) -> None. Update D from E: for k in E: D[k] = E[k].""" + if not self.loaded: self.load() + self._data.update(d) + + def setdefault(self, key, default=None): + """D.setdefault(k[,d]) -> D.get(k,d), also set D[k]=d if k not in D.""" + if not self.loaded: self.load() + return self._data.setdefault(key, default) + + def clear(self): + """D.clear() -> None. Remove all items from D.""" + if not self.loaded: self.load() + self._data.clear() + + def keys(self): + """D.keys() -> list of D's keys.""" + if not self.loaded: self.load() + return self._data.keys() + + def items(self): + """D.items() -> list of D's (key, value) pairs, as 2-tuples.""" + if not self.loaded: self.load() + return self._data.items() + + def values(self): + """D.values() -> list of D's values.""" + if not self.loaded: self.load() + return self._data.values() + + +class RamSession(Session): + + # Class-level objects. Don't rebind these! 
+    cache = {}
+    locks = {}
+
+    def clean_up(self):
+        """Clean up expired sessions."""
+        now = datetime.datetime.now()
+        for id, (data, expiration_time) in self.cache.items():
+            if expiration_time < now:
+                try:
+                    del self.cache[id]
+                except KeyError:
+                    pass
+                try:
+                    del self.locks[id]
+                except KeyError:
+                    pass
+
+    def _exists(self):
+        return self.id in self.cache
+
+    def _load(self):
+        return self.cache.get(self.id)
+
+    def _save(self, expiration_time):
+        self.cache[self.id] = (self._data, expiration_time)
+
+    def _delete(self):
+        self.cache.pop(self.id, None)
+
+    def acquire_lock(self):
+        """Acquire an exclusive lock on the currently-loaded session data."""
+        self.locked = True
+        self.locks.setdefault(self.id, threading.RLock()).acquire()
+
+    def release_lock(self):
+        """Release the lock on the currently-loaded session data."""
+        self.locks[self.id].release()
+        self.locked = False
+
+    def __len__(self):
+        """Return the number of active sessions."""
+        return len(self.cache)
+
+
+class FileSession(Session):
+    """Implementation of the File backend for sessions
+
+    storage_path: the folder where session data will be saved. Each session
+        will be saved as pickle.dump(data, expiration_time) in its own file;
+        the filename will be self.SESSION_PREFIX + self.id.
+    """
+
+    SESSION_PREFIX = 'session-'
+    LOCK_SUFFIX = '.lock'
+
+    def __init__(self, id=None, **kwargs):
+        # The 'storage_path' arg is required for file-based sessions.
+        kwargs['storage_path'] = os.path.abspath(kwargs['storage_path'])
+        Session.__init__(self, id=id, **kwargs)
+
+    def setup(cls, **kwargs):
+        """Set up the storage system for file-based sessions.
+
+        This should only be called once per process; this will be done
+        automatically when using sessions.init (as the built-in Tool does).
+        """
+        # The 'storage_path' arg is required for file-based sessions.
+ kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + + for k, v in kwargs.iteritems(): + setattr(cls, k, v) + + # Warn if any lock files exist at startup. + lockfiles = [fname for fname in os.listdir(cls.storage_path) + if (fname.startswith(cls.SESSION_PREFIX) + and fname.endswith(cls.LOCK_SUFFIX))] + if lockfiles: + plural = ('', 's')[len(lockfiles) > 1] + warn("%s session lockfile%s found at startup. If you are " + "only running one process, then you may need to " + "manually delete the lockfiles found at %r." + % (len(lockfiles), plural, cls.storage_path)) + setup = classmethod(setup) + + def _get_file_path(self): + f = os.path.join(self.storage_path, self.SESSION_PREFIX + self.id) + if not os.path.abspath(f).startswith(self.storage_path): + raise cherrypy.HTTPError(400, "Invalid session id in cookie.") + return f + + def _exists(self): + path = self._get_file_path() + return os.path.exists(path) + + def _load(self, path=None): + if path is None: + path = self._get_file_path() + try: + f = open(path, "rb") + try: + return pickle.load(f) + finally: + f.close() + except (IOError, EOFError): + return None + + def _save(self, expiration_time): + f = open(self._get_file_path(), "wb") + try: + pickle.dump((self._data, expiration_time), f) + finally: + f.close() + + def _delete(self): + try: + os.unlink(self._get_file_path()) + except OSError: + pass + + def acquire_lock(self, path=None): + """Acquire an exclusive lock on the currently-loaded session data.""" + if path is None: + path = self._get_file_path() + path += self.LOCK_SUFFIX + while True: + try: + lockfd = os.open(path, os.O_CREAT|os.O_WRONLY|os.O_EXCL) + except OSError: + time.sleep(0.1) + else: + os.close(lockfd) + break + self.locked = True + + def release_lock(self, path=None): + """Release the lock on the currently-loaded session data.""" + if path is None: + path = self._get_file_path() + os.unlink(path + self.LOCK_SUFFIX) + self.locked = False + + def clean_up(self): + """Clean up 
expired sessions.""" + now = datetime.datetime.now() + # Iterate over all session files in self.storage_path + for fname in os.listdir(self.storage_path): + if (fname.startswith(self.SESSION_PREFIX) + and not fname.endswith(self.LOCK_SUFFIX)): + # We have a session file: lock and load it and check + # if it's expired. If it fails, nevermind. + path = os.path.join(self.storage_path, fname) + self.acquire_lock(path) + try: + contents = self._load(path) + # _load returns None on IOError + if contents is not None: + data, expiration_time = contents + if expiration_time < now: + # Session expired: deleting it + os.unlink(path) + finally: + self.release_lock(path) + + def __len__(self): + """Return the number of active sessions.""" + return len([fname for fname in os.listdir(self.storage_path) + if (fname.startswith(self.SESSION_PREFIX) + and not fname.endswith(self.LOCK_SUFFIX))]) + + +class PostgresqlSession(Session): + """ Implementation of the PostgreSQL backend for sessions. It assumes + a table like this: + + create table session ( + id varchar(40), + data text, + expiration_time timestamp + ) + + You must provide your own get_db function. + """ + + def __init__(self, id=None, **kwargs): + Session.__init__(self, id, **kwargs) + self.cursor = self.db.cursor() + + def setup(cls, **kwargs): + """Set up the storage system for Postgres-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). 
+        """
+        for k, v in kwargs.iteritems():
+            setattr(cls, k, v)
+
+        cls.db = cls.get_db()
+    setup = classmethod(setup)
+
+    def __del__(self):
+        if self.cursor:
+            self.cursor.close()
+        self.db.commit()
+
+    def _exists(self):
+        # Select session data from table
+        self.cursor.execute('select data, expiration_time from session '
+                            'where id=%s', (self.id,))
+        rows = self.cursor.fetchall()
+        return bool(rows)
+
+    def _load(self):
+        # Select session data from table
+        self.cursor.execute('select data, expiration_time from session '
+                            'where id=%s', (self.id,))
+        rows = self.cursor.fetchall()
+        if not rows:
+            return None
+
+        pickled_data, expiration_time = rows[0]
+        data = pickle.loads(pickled_data)
+        return data, expiration_time
+
+    def _save(self, expiration_time):
+        pickled_data = pickle.dumps(self._data)
+        self.cursor.execute('update session set data = %s, '
+                            'expiration_time = %s where id = %s',
+                            (pickled_data, expiration_time, self.id))
+
+    def _delete(self):
+        self.cursor.execute('delete from session where id=%s', (self.id,))
+
+    def acquire_lock(self):
+        """Acquire an exclusive lock on the currently-loaded session data."""
+        # We use the "for update" clause to lock the row
+        self.locked = True
+        self.cursor.execute('select id from session where id=%s for update',
+                            (self.id,))
+
+    def release_lock(self):
+        """Release the lock on the currently-loaded session data."""
+        # We just close the cursor and that will remove the lock
+        # introduced by the "for update" clause
+        self.cursor.close()
+        self.locked = False
+
+    def clean_up(self):
+        """Clean up expired sessions."""
+        self.cursor.execute('delete from session where expiration_time < %s',
+                            (datetime.datetime.now(),))
+
+
+class MemcachedSession(Session):
+
+    # The most popular memcached client for Python isn't thread-safe.
+    # Wrap all .get and .set operations in a single lock.
+    mc_lock = threading.RLock()
+
+    # This is a separate set of locks per session id.
+ locks = {} + + servers = ['127.0.0.1:11211'] + + def setup(cls, **kwargs): + """Set up the storage system for memcached-based sessions. + + This should only be called once per process; this will be done + automatically when using sessions.init (as the built-in Tool does). + """ + for k, v in kwargs.iteritems(): + setattr(cls, k, v) + + import memcache + cls.cache = memcache.Client(cls.servers) + setup = classmethod(setup) + + def _exists(self): + self.mc_lock.acquire() + try: + return bool(self.cache.get(self.id)) + finally: + self.mc_lock.release() + + def _load(self): + self.mc_lock.acquire() + try: + return self.cache.get(self.id) + finally: + self.mc_lock.release() + + def _save(self, expiration_time): + # Send the expiration time as "Unix time" (seconds since 1/1/1970) + td = int(time.mktime(expiration_time.timetuple())) + self.mc_lock.acquire() + try: + if not self.cache.set(self.id, (self._data, expiration_time), td): + raise AssertionError("Session data for id %r not set." % self.id) + finally: + self.mc_lock.release() + + def _delete(self): + self.cache.delete(self.id) + + def acquire_lock(self): + """Acquire an exclusive lock on the currently-loaded session data.""" + self.locked = True + self.locks.setdefault(self.id, threading.RLock()).acquire() + + def release_lock(self): + """Release the lock on the currently-loaded session data.""" + self.locks[self.id].release() + self.locked = False + + def __len__(self): + """Return the number of active sessions.""" + raise NotImplementedError + + +# Hook functions (for CherryPy tools) + +def save(): + """Save any changed session data.""" + + if not hasattr(cherrypy.serving, "session"): + return + + # Guard against running twice + if hasattr(cherrypy.request, "_sessionsaved"): + return + cherrypy.request._sessionsaved = True + + if cherrypy.response.stream: + # If the body is being streamed, we have to save the data + # *after* the response has been written out + cherrypy.request.hooks.attach('on_end_request', 
cherrypy.session.save) + else: + # If the body is not being streamed, we save the data now + # (so we can release the lock). + if isinstance(cherrypy.response.body, types.GeneratorType): + cherrypy.response.collapse_body() + cherrypy.session.save() +save.failsafe = True + +def close(): + """Close the session object for this request.""" + sess = getattr(cherrypy.serving, "session", None) + if getattr(sess, "locked", False): + # If the session is still locked we release the lock + sess.release_lock() +close.failsafe = True +close.priority = 90 + + +def init(storage_type='ram', path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False, clean_freq=5, + persistent=True, **kwargs): + """Initialize session object (using cookies). + + storage_type: one of 'ram', 'file', 'postgresql'. This will be used + to look up the corresponding class in cherrypy.lib.sessions + globals. For example, 'file' will use the FileSession class. + path: the 'path' value to stick in the response cookie metadata. + path_header: if 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + name: the name of the cookie. + timeout: the expiration timeout (in minutes) for the stored session data. + If 'persistent' is True (the default), this is also the timeout + for the cookie. + domain: the cookie domain. + secure: if False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + clean_freq (minutes): the poll rate for expired session cleanup. + persistent: if True (the default), the 'timeout' argument will be used + to expire the cookie. If False, the cookie will not have an expiry, + and the cookie will be a "session cookie" which expires when the + browser is closed. + + Any additional kwargs will be bound to the new Session instance, + and may be specific to the storage type. See the subclass of Session + you're using for more information. 
+ """ + + request = cherrypy.request + + # Guard against running twice + if hasattr(request, "_session_init_flag"): + return + request._session_init_flag = True + + # Check if request came with a session ID + id = None + if name in request.cookie: + id = request.cookie[name].value + + # Find the storage class and call setup (first time only). + storage_class = storage_type.title() + 'Session' + storage_class = globals()[storage_class] + if not hasattr(cherrypy, "session"): + if hasattr(storage_class, "setup"): + storage_class.setup(**kwargs) + + # Create and attach a new Session instance to cherrypy.serving. + # It will possess a reference to (and lock, and lazily load) + # the requested session data. + kwargs['timeout'] = timeout + kwargs['clean_freq'] = clean_freq + cherrypy.serving.session = sess = storage_class(id, **kwargs) + def update_cookie(id): + """Update the cookie every time the session id changes.""" + cherrypy.response.cookie[name] = id + sess.id_observers.append(update_cookie) + + # Create cherrypy.session which will proxy to cherrypy.serving.session + if not hasattr(cherrypy, "session"): + cherrypy.session = cherrypy._ThreadLocalProxy('session') + + if persistent: + cookie_timeout = timeout + else: + # See http://support.microsoft.com/kb/223799/EN-US/ + # and http://support.mozilla.com/en-US/kb/Cookies + cookie_timeout = None + set_response_cookie(path=path, path_header=path_header, name=name, + timeout=cookie_timeout, domain=domain, secure=secure) + + +def set_response_cookie(path=None, path_header=None, name='session_id', + timeout=60, domain=None, secure=False): + """Set a response cookie for the client. + + path: the 'path' value to stick in the response cookie metadata. + path_header: if 'path' is None (the default), then the response + cookie 'path' will be pulled from request.headers[path_header]. + name: the name of the cookie. + timeout: the expiration timeout for the cookie. 
If 0 or other boolean + False, no 'expires' param will be set, and the cookie will be a + "session cookie" which expires when the browser is closed. + domain: the cookie domain. + secure: if False (the default) the cookie 'secure' value will not + be set. If True, the cookie 'secure' value will be set (to 1). + """ + # Set response cookie + cookie = cherrypy.response.cookie + cookie[name] = cherrypy.serving.session.id + cookie[name]['path'] = (path or cherrypy.request.headers.get(path_header) + or '/') + + # We'd like to use the "max-age" param as indicated in + # http://www.faqs.org/rfcs/rfc2109.html but IE doesn't + # save it to disk and the session is lost if people close + # the browser. So we have to use the old "expires" ... sigh ... +## cookie[name]['max-age'] = timeout * 60 + if timeout: + cookie[name]['expires'] = http.HTTPDate(time.time() + (timeout * 60)) + if domain is not None: + cookie[name]['domain'] = domain + if secure: + cookie[name]['secure'] = 1 + + +def expire(): + """Expire the current session cookie.""" + name = cherrypy.request.config.get('tools.sessions.name', 'session_id') + one_year = 60 * 60 * 24 * 365 + exp = time.gmtime(time.time() - one_year) + t = time.strftime("%a, %d-%b-%Y %H:%M:%S GMT", exp) + cherrypy.response.cookie[name]['expires'] = t + + diff --git a/src/cherrypy/lib/sessions.py b/src/cherrypy/lib/sessions.py index fe842192f3..fb3676c557 100644 --- a/src/cherrypy/lib/sessions.py +++ b/src/cherrypy/lib/sessions.py @@ -143,7 +143,7 @@ class Session(object): # Stick the clean_thread in the class, not the instance. # The instances are created and destroyed per-request. cls = self.__class__ - if not cls.clean_thread: + if self.clean_freq and not cls.clean_thread: # clean_up is in instancemethod and not a classmethod, # so that tool config can be accessed inside the method. 
t = cherrypy.process.plugins.Monitor( @@ -282,14 +282,19 @@ class FileSession(Session): SESSION_PREFIX = 'session-' LOCK_SUFFIX = '.lock' + def __init__(self, id=None, **kwargs): + # The 'storage_path' arg is required for file-based sessions. + kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + Session.__init__(self, id=id, **kwargs) + def setup(cls, **kwargs): """Set up the storage system for file-based sessions. This should only be called once per process; this will be done automatically when using sessions.init (as the built-in Tool does). """ - if 'storage_path' in kwargs: - kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) + # The 'storage_path' arg is required for file-based sessions. + kwargs['storage_path'] = os.path.abspath(kwargs['storage_path']) for k, v in kwargs.iteritems(): setattr(cls, k, v) diff --git a/src/cherrypy/lib/static.py b/src/cherrypy/lib/static.py index e9df62f7bd..f4e3efe054 100644 --- a/src/cherrypy/lib/static.py +++ b/src/cherrypy/lib/static.py @@ -10,7 +10,7 @@ import time import urllib import cherrypy -from cherrypy.lib import cptools, http +from cherrypy.lib import cptools, http, file_generator_limited def serve_file(path, content_type=None, disposition=None, name=None): @@ -83,13 +83,15 @@ def serve_file(path, content_type=None, disposition=None, name=None): if len(r) == 1: # Return a single-part response. start, stop = r[0] + if stop > c_len: + stop = c_len r_len = stop - start response.status = "206 Partial Content" response.headers['Content-Range'] = ("bytes %s-%s/%s" % (start, stop - 1, c_len)) response.headers['Content-Length'] = r_len bodyfile.seek(start) - response.body = bodyfile.read(r_len) + response.body = file_generator_limited(bodyfile, r_len) else: # Return a multipart/byteranges response. 
response.status = "206 Partial Content" @@ -111,7 +113,8 @@ def serve_file(path, content_type=None, disposition=None, name=None): yield ("\r\nContent-range: bytes %s-%s/%s\r\n\r\n" % (start, stop - 1, c_len)) bodyfile.seek(start) - yield bodyfile.read(stop - start) + for chunk in file_generator_limited(bodyfile, stop-start): + yield chunk yield "\r\n" # Final boundary yield "--" + boundary + "--" diff --git a/src/cherrypy/lib/xmlrpc.py b/src/cherrypy/lib/xmlrpc.py index c95970f205..59ee0278fe 100644 --- a/src/cherrypy/lib/xmlrpc.py +++ b/src/cherrypy/lib/xmlrpc.py @@ -42,7 +42,7 @@ def respond(body, encoding='utf-8', allow_none=0): encoding=encoding, allow_none=allow_none)) -def on_error(): +def on_error(*args, **kwargs): body = str(sys.exc_info()[1]) import xmlrpclib _set_response(xmlrpclib.dumps(xmlrpclib.Fault(1, body))) diff --git a/src/cherrypy/process/plugins.py b/src/cherrypy/process/plugins.py index 7f3cde7ec8..0e8b4bf919 100644 --- a/src/cherrypy/process/plugins.py +++ b/src/cherrypy/process/plugins.py @@ -21,6 +21,7 @@ class SimplePlugin(object): def subscribe(self): """Register this object as a (multi-channel) listener on the bus.""" for channel in self.bus.listeners: + # Subscribe self.start, self.exit, etc. if present. method = getattr(self, channel, None) if method is not None: self.bus.subscribe(channel, method) @@ -28,6 +29,7 @@ class SimplePlugin(object): def unsubscribe(self): """Unregister this object as a listener on the bus.""" for channel in self.bus.listeners: + # Unsubscribe self.start, self.exit, etc. if present. 
method = getattr(self, channel, None) if method is not None: self.bus.unsubscribe(channel, method) @@ -213,9 +215,9 @@ class DropPrivileges(SimplePlugin): else: self.bus.log('Started as uid: %r gid: %r' % current_ids()) if self.gid is not None: - os.setgid(gid) + os.setgid(self.gid) if self.uid is not None: - os.setuid(uid) + os.setuid(self.uid) self.bus.log('Running as uid: %r gid: %r' % current_ids()) # umask @@ -231,7 +233,10 @@ class DropPrivileges(SimplePlugin): (old_umask, self.umask)) self.finalized = True - start.priority = 75 + # This is slightly higher than the priority for server.start + # in order to facilitate the most common use: starting on a low + # port (which requires root) and then dropping to another user. + start.priority = 77 class Daemonizer(SimplePlugin): diff --git a/src/cherrypy/process/servers.py b/src/cherrypy/process/servers.py index f4baf83318..ac4178db0b 100644 --- a/src/cherrypy/process/servers.py +++ b/src/cherrypy/process/servers.py @@ -124,8 +124,16 @@ class FlupFCGIServer(object): """Adapter for a flup.server.fcgi.WSGIServer.""" def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the FCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. from flup.server.fcgi import WSGIServer - self.fcgiserver = WSGIServer(*args, **kwargs) + self.fcgiserver = WSGIServer(*self.args, **self.kwargs) # TODO: report this bug upstream to flup. 
# If we don't set _oldSIGs on Windows, we get: # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", @@ -135,11 +143,8 @@ class FlupFCGIServer(object): # line 156, in _restoreSignalHandlers # for signum,handler in self._oldSIGs: # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.fcgiserver._installSignalHandlers = lambda: None self.fcgiserver._oldSIGs = [] - self.ready = False - - def start(self): - """Start the FCGI server.""" self.ready = True self.fcgiserver.run() @@ -152,6 +157,43 @@ class FlupFCGIServer(object): self.fcgiserver._threadPool.maxSpare = 0 +class FlupSCGIServer(object): + """Adapter for a flup.server.scgi.WSGIServer.""" + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.ready = False + + def start(self): + """Start the SCGI server.""" + # We have to instantiate the server class here because its __init__ + # starts a threadpool. If we do it too early, daemonize won't work. + from flup.server.scgi import WSGIServer + self.scgiserver = WSGIServer(*self.args, **self.kwargs) + # TODO: report this bug upstream to flup. + # If we don't set _oldSIGs on Windows, we get: + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 108, in run + # self._restoreSignalHandlers() + # File "C:\Python24\Lib\site-packages\flup\server\threadedserver.py", + # line 156, in _restoreSignalHandlers + # for signum,handler in self._oldSIGs: + # AttributeError: 'WSGIServer' object has no attribute '_oldSIGs' + self.scgiserver._installSignalHandlers = lambda: None + self.scgiserver._oldSIGs = [] + self.ready = True + self.scgiserver.run() + + def stop(self): + """Stop the HTTP server.""" + self.ready = False + # Forcibly stop the scgi server main event loop. + self.scgiserver._keepGoing = False + # Force all worker threads to die off. 
+ self.scgiserver._threadPool.maxSpare = 0 + + def client_host(server_host): """Return the host on which a client can connect to the given listener.""" if server_host == '0.0.0.0': diff --git a/src/cherrypy/process/win32.py b/src/cherrypy/process/win32.py index 9db928da5a..0ca43d5e9b 100644 --- a/src/cherrypy/process/win32.py +++ b/src/cherrypy/process/win32.py @@ -71,8 +71,7 @@ class ConsoleCtrlHandler(plugins.SimplePlugin): class Win32Bus(wspbus.Bus): """A Web Site Process Bus implementation for Win32. - Instead of using time.sleep for blocking, this bus uses native - win32event objects. It also responds to console events. + Instead of time.sleep, this bus blocks using native win32event objects. """ def __init__(self): @@ -99,15 +98,21 @@ class Win32Bus(wspbus.Bus): state = property(_get_state, _set_state) def wait(self, state, interval=0.1): - """Wait for the given state, KeyboardInterrupt or SystemExit. + """Wait for the given state(s), KeyboardInterrupt or SystemExit. Since this class uses native win32event objects, the interval argument is ignored. 
""" - # Don't wait for an event that beat us to the punch ;) - if self.state != state: - event = self._get_state_event(state) - win32event.WaitForSingleObject(event, win32event.INFINITE) + if isinstance(state, (tuple, list)): + # Don't wait for an event that beat us to the punch ;) + if self.state not in state: + events = tuple([self._get_state_event(s) for s in state]) + win32event.WaitForMultipleObjects(events, 0, win32event.INFINITE) + else: + # Don't wait for an event that beat us to the punch ;) + if self.state != state: + event = self._get_state_event(state) + win32event.WaitForSingleObject(event, win32event.INFINITE) class _ControlCodes(dict): diff --git a/src/cherrypy/process/wspbus.py b/src/cherrypy/process/wspbus.py index 8c0d84bcbc..26abb4702c 100644 --- a/src/cherrypy/process/wspbus.py +++ b/src/cherrypy/process/wspbus.py @@ -153,9 +153,13 @@ class Bus(object): e.code = 1 raise except: - self.log("Error in %r listener %r" % (channel, listener), - level=40, traceback=True) exc = sys.exc_info()[1] + if channel == 'log': + # Assume any further messages to 'log' will fail. + pass + else: + self.log("Error in %r listener %r" % (channel, listener), + level=40, traceback=True) if exc: raise return output @@ -248,9 +252,14 @@ class Bus(object): self._do_execv() def wait(self, state, interval=0.1): - """Wait for the given state.""" + """Wait for the given state(s).""" + if isinstance(state, (tuple, list)): + states = state + else: + states = [state] + def _wait(): - while self.state != state: + while self.state not in states: time.sleep(interval) # From http://psyco.sourceforge.net/psycoguide/bugs.html: diff --git a/src/cherrypy/scaffold/__init__.py b/src/cherrypy/scaffold/__init__.py new file mode 100644 index 0000000000..f50cc213d3 --- /dev/null +++ b/src/cherrypy/scaffold/__init__.py @@ -0,0 +1,61 @@ +""", a CherryPy application. + +Use this as a base for creating new CherryPy applications. 
When you want +to make a new app, copy and paste this folder to some other location +(maybe site-packages) and rename it to the name of your project, +then tweak as desired. + +Even before any tweaking, this should serve a few demonstration pages. +Change to this directory and run: + + python cherrypy\cherryd -c cherrypy\scaffold\site.conf + +""" + +import cherrypy +from cherrypy import tools, url + +import os +local_dir = os.path.join(os.getcwd(), os.path.dirname(__file__)) + + +class Root: + + _cp_config = {'tools.log_tracebacks.on': True, + } + + def index(self): + return """ +Try some other path, +or a default path.
+Or, just look at the pretty picture:
+ +""" % (url("other"), url("else"), + url("files/made_with_cherrypy_small.png")) + index.exposed = True + + def default(self, *args, **kwargs): + return "args: %s kwargs: %s" % (args, kwargs) + default.exposed = True + + def other(self, a=2, b='bananas', c=None): + cherrypy.response.headers['Content-Type'] = 'text/plain' + if c is None: + return "Have %d %s." % (int(a), b) + else: + return "Have %d %s, %s." % (int(a), b, c) + other.exposed = True + + files = cherrypy.tools.staticdir.handler( + section="/files", + dir=os.path.join(local_dir, "static"), + # Ignore .php files, etc. + match=r'\.(css|gif|html?|ico|jpe?g|js|png|swf|xml)$', + ) + + +root = Root() + +# Uncomment the following to use your own favicon instead of CP's default. +#favicon_path = os.path.join(local_dir, "favicon.ico") +#root.favicon_ico = tools.staticfile.handler(filename=favicon_path) diff --git a/src/cherrypy/wsgiserver/__init__.py b/src/cherrypy/wsgiserver/__init__.py index c3172c73ff..a92869f56d 100644 --- a/src/cherrypy/wsgiserver/__init__.py +++ b/src/cherrypy/wsgiserver/__init__.py @@ -88,6 +88,9 @@ try: import cStringIO as StringIO except ImportError: import StringIO + +_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring) + import sys import threading import time @@ -332,7 +335,12 @@ class HTTPRequest(object): environ = self.environ - method, path, req_protocol = request_line.strip().split(" ", 2) + try: + method, path, req_protocol = request_line.strip().split(" ", 2) + except ValueError: + self.simple_response(400, "Malformed Request-Line") + return + environ["REQUEST_METHOD"] = method # path may be an abs_path (including "http://host.domain.tld"); @@ -402,13 +410,6 @@ class HTTPRequest(object): self.simple_response("413 Request Entity Too Large") return - # Set AUTH_TYPE, REMOTE_USER - creds = environ.get("HTTP_AUTHORIZATION", "").split(" ", 1) - environ["AUTH_TYPE"] = creds[0] - if creds[0].lower() == 'basic': - user, pw = 
base64.decodestring(creds[1]).split(":", 1) - environ["REMOTE_USER"] = user - # Persistent connection support if self.response_protocol == "HTTP/1.1": # Both server and client are HTTP/1.1 @@ -588,7 +589,12 @@ class HTTPRequest(object): buf.append("\r\n") if msg: buf.append(msg) - self.wfile.sendall("".join(buf)) + + try: + self.wfile.sendall("".join(buf)) + except socket.error, x: + if x.args[0] not in socket_errors_to_ignore: + raise def start_response(self, status, headers, exc_info = None): """WSGI callable to begin the HTTP response.""" @@ -646,7 +652,8 @@ class HTTPRequest(object): if status < 200 or status in (204, 205, 304): pass else: - if self.response_protocol == 'HTTP/1.1': + if (self.response_protocol == 'HTTP/1.1' + and self.environ["REQUEST_METHOD"] != 'HEAD'): # Use the chunked transfer-coding self.chunked_write = True self.outheaders.append(("Transfer-Encoding", "chunked")) @@ -711,147 +718,327 @@ class FatalSSLAlert(Exception): pass -class CP_fileobject(socket._fileobject): - """Faux file object attached to a socket object.""" - - def sendall(self, data): - """Sendall for non-blocking sockets.""" - while data: - try: - bytes_sent = self.send(data) - data = data[bytes_sent:] - except socket.error, e: - if e.args[0] not in socket_errors_nonblocking: - raise - - def send(self, data): - return self._sock.send(data) - - def flush(self): - if self._wbuf: - buffer = "".join(self._wbuf) - self._wbuf = [] - self.sendall(buffer) - - def recv(self, size): - while True: - try: - return self._sock.recv(size) - except socket.error, e: - if e.args[0] not in socket_errors_nonblocking: - raise - - def read(self, size=-1): - if size < 0: - # Read until EOF - buffers = [self._rbuf] - self._rbuf = "" - if self._rbufsize <= 1: - recv_size = self.default_bufsize - else: - recv_size = self._rbufsize - - while True: - data = self.recv(recv_size) - if not data: - break - buffers.append(data) - return "".join(buffers) - else: - # Read until size bytes or EOF seen, 
whichever comes first - data = self._rbuf - buf_len = len(data) - if buf_len >= size: - self._rbuf = data[size:] - return data[:size] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - left = size - buf_len - recv_size = max(self._rbufsize, left) - data = self.recv(recv_size) - if not data: - break - buffers.append(data) - n = len(data) - if n >= left: - self._rbuf = data[left:] - buffers[-1] = data[:left] - break - buf_len += n - return "".join(buffers) +if not _fileobject_uses_str_type: + class CP_fileobject(socket._fileobject): + """Faux file object attached to a socket object.""" - def readline(self, size=-1): - data = self._rbuf - if size < 0: - # Read until \n or EOF, whichever comes first - if self._rbufsize <= 1: - # Speed up unbuffered case - assert data == "" - buffers = [] - while data != "\n": - data = self.recv(1) + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + return self._sock.send(data) + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + return self._sock.recv(size) + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + def read(self, size=-1): + # Use max, disallow tiny reads in a loop as they are very inefficient. + # We never leave read() with any leftover data from a new recv() call + # in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned by + # recv() minimizes memory usage and fragmentation that occurs when + # rbufsize is large compared to the typical return value of recv(). 
+ buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. 
+ return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, "recv(%d) returned %d bytes" % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + #assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + data = None + recv = self.recv + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + + buf.seek(0, 2) # seek end + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. 
Avoid data copy through buf when returning + # a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). + return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + #assert buf_len == buf.tell() + return buf.getvalue() + +else: + class CP_fileobject(socket._fileobject): + """Faux file object attached to a socket object.""" + + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + return self._sock.send(data) + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + return self._sock.recv(size) + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + def read(self, size=-1): + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) if not data: break buffers.append(data) return "".join(buffers) - nl = data.find('\n') - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - return data[:nl] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - data = self.recv(self._rbufsize) - if not data: - break - buffers.append(data) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + 
recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + while data != "\n": + data = self.recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) nl = data.find('\n') if nl >= 0: nl += 1 self._rbuf = data[nl:] - buffers[-1] = data[:nl] - break - return "".join(buffers) - else: - # Read until size bytes or \n or EOF seen, whichever comes first - nl = data.find('\n', 0, size) - if nl >= 0: - nl += 1 - self._rbuf = data[nl:] - return data[:nl] - buf_len = len(data) - if buf_len >= size: - self._rbuf = data[size:] - return data[:size] - buffers = [] - if data: - buffers.append(data) - self._rbuf = "" - while True: - data = self.recv(self._rbufsize) - if not data: - break - buffers.append(data) - left = size - buf_len - nl = data.find('\n', 0, left) + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes first + nl = data.find('\n', 0, size) if nl >= 0: nl += 1 self._rbuf = data[nl:] - buffers[-1] = data[:nl] - break - n = len(data) - if n >= left: - self._rbuf = data[left:] - buffers[-1] = data[:left] - break - buf_len += n - return "".join(buffers) + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while 
True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) class SSL_fileobject(CP_fileobject): @@ -1203,6 +1390,27 @@ class SSLConnection: """ % (f, f) +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + except ImportError: + def prevent_socket_inheritance(sock): + """Dummy function, since neither fcntl nor ctypes are available.""" + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + class CherryPyWSGIServer(object): """An HTTP server for WSGI. 
@@ -1249,7 +1457,7 @@ class CherryPyWSGIServer(object): protocol = "HTTP/1.1" _bind_addr = "127.0.0.1" - version = "CherryPy/3.1.0" + version = "CherryPy/3.1.1" ready = False _interrupt = None @@ -1396,6 +1604,7 @@ class CherryPyWSGIServer(object): def bind(self, family, type, proto=0): """Create (or recreate) the actual socket object.""" self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if self.nodelay: self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -1409,12 +1618,25 @@ class CherryPyWSGIServer(object): ctx.use_certificate_file(self.ssl_certificate) self.socket = SSLConnection(ctx, self.socket) self.populate_ssl_environ() + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See http://www.cherrypy.org/ticket/871. + if (not isinstance(self.bind_addr, basestring) + and self.bind_addr[0] == '::' and family == socket.AF_INET6): + try: + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + self.socket.bind(self.bind_addr) def tick(self): """Accept a new connection and put it on the Queue.""" try: s, addr = self.socket.accept() + prevent_socket_inheritance(s) if not self.ready: return if hasattr(s, 'settimeout'): @@ -1423,7 +1645,8 @@ class CherryPyWSGIServer(object): environ = self.environ.copy() # SERVER_SOFTWARE is common for IIS. It's also helpful for # us to pass a default value for the "Server" response header. - environ["SERVER_SOFTWARE"] = "%s WSGI Server" % self.version + if environ.get("SERVER_SOFTWARE") is None: + environ["SERVER_SOFTWARE"] = "%s WSGI Server" % self.version # set a non-standard environ entry so the WSGI app can know what # the *real* server protocol is (and what features to support). # See http://www.faqs.org/rfcs/rfc2145.html.