git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Make Application an HTTPServerConnectionDelegate.
author Ben Darnell <ben@bendarnell.com>
Sun, 16 Mar 2014 16:57:54 +0000 (12:57 -0400)
committer Ben Darnell <ben@bendarnell.com>
Sun, 16 Mar 2014 16:57:54 +0000 (12:57 -0400)
HTTPServer now forwards delegate events to its request callback
if it implements this interface.  This allows the application to be
in the loop as the request is being read.

tornado/http1connection.py
tornado/httpserver.py
tornado/httputil.py
tornado/web.py

index c2a5eef9027de273701b5611c41a07e2daf319a2..3a1284c06fc37ae3b468fed2602528429c1d17aa 100644 (file)
@@ -16,6 +16,8 @@
 
 from __future__ import absolute_import, division, print_function, with_statement
 
+import socket
+
 from tornado.concurrent import Future
 from tornado.escape import native_str, utf8
 from tornado import gen
@@ -40,6 +42,13 @@ class HTTP1Connection(object):
         # and its socket attribute replaced with None.
         self.address_family = stream.socket.family
         self.no_keep_alive = no_keep_alive
+        # In HTTPServerRequest we want an IP, not a full socket address.
+        if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
+            address is not None):
+            self.remote_ip = address[0]
+        else:
+            # Unix (or other) socket; fake the remote address.
+            self.remote_ip = '0.0.0.0'
         if protocol:
             self.protocol = protocol
         elif isinstance(stream, iostream.SSLIOStream):
index 35d5a9750b74dbc8469dd3014d4ea98ecf311cdf..ff512ba1a634a00358c99bd974c2edde7a67c436 100644 (file)
@@ -154,51 +154,78 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
     def start_request(self, connection):
         return _ServerRequestAdapter(self, connection)
 
+
 class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
-    """Adapts the `HTTPMessageDelegate` interface to the `HTTPServerRequest`
-    interface expected by our clients.
+    """Adapts the `HTTPMessageDelegate` interface to the interface expected
+    by our clients.
     """
     def __init__(self, server, connection):
         self.server = server
         self.connection = connection
-        self._chunks = []
-
-    def headers_received(self, start_line, headers):
-        # HTTPRequest wants an IP, not a full socket address
-        if self.connection.address_family in (socket.AF_INET, socket.AF_INET6):
-            remote_ip = self.connection.address[0]
+        self.request = None
+        if isinstance(server.request_callback,
+                      httputil.HTTPServerConnectionDelegate):
+            self.delegate = server.request_callback.start_request(connection)
+            self._chunks = None
         else:
-            # Unix (or other) socket; fake the remote address
-            remote_ip = '0.0.0.0'
+            self.delegate = None
+            self._chunks = []
+
+    def _apply_xheaders(self, headers):
+        """Rewrite the connection.remote_ip and connection.protocol fields.
+
+        This is hacky, but the movement of logic between `HTTPServer`
+        and `.Application` leaves us without a clean place to do this.
+        """
+        self._orig_remote_ip = self.connection.remote_ip
+        self._orig_protocol = self.connection.protocol
+        # Squid uses X-Forwarded-For, others use X-Real-Ip
+        ip = headers.get("X-Forwarded-For", self.connection.remote_ip)
+        ip = ip.split(',')[-1].strip()
+        ip = headers.get("X-Real-Ip", ip)
+        if netutil.is_valid_ip(ip):
+            self.connection.remote_ip = ip
+        # AWS uses X-Forwarded-Proto
+        proto_header = headers.get(
+            "X-Scheme", headers.get("X-Forwarded-Proto",
+                                    self.connection.protocol))
+        if proto_header in ("http", "https"):
+            self.connection.protocol = proto_header
+
+    def _unapply_xheaders(self):
+        """Undo changes from `_apply_xheaders`.
+
+        Xheaders are per-request so they should not leak to the next
+        request on the same connection.
+        """
+        self.connection.remote_ip = self._orig_remote_ip
+        self.connection.protocol = self._orig_protocol
 
-        protocol = self.connection.protocol
-
-        # xheaders can override the defaults
+    def headers_received(self, start_line, headers):
         if self.server.xheaders:
-            # Squid uses X-Forwarded-For, others use X-Real-Ip
-            ip = headers.get("X-Forwarded-For", remote_ip)
-            ip = ip.split(',')[-1].strip()
-            ip = headers.get("X-Real-Ip", ip)
-            if netutil.is_valid_ip(ip):
-                remote_ip = ip
-            # AWS uses X-Forwarded-Proto
-            proto_header = headers.get(
-                "X-Scheme", headers.get("X-Forwarded-Proto", protocol))
-            if proto_header in ("http", "https"):
-                protocol = proto_header
-
-        self.request = httputil.HTTPServerRequest(
-            connection=self.connection, method=start_line.method,
-            uri=start_line.path, version=start_line.version,
-            headers=headers, remote_ip=remote_ip, protocol=protocol)
+            self._apply_xheaders(headers)
+        if self.delegate is None:
+            self.request = httputil.HTTPServerRequest(
+                connection=self.connection, start_line=start_line,
+                headers=headers)
+        else:
+            self.delegate.headers_received(start_line, headers)
 
     def data_received(self, chunk):
-        self._chunks.append(chunk)
+        if self.delegate is None:
+            self._chunks.append(chunk)
+        else:
+            self.delegate.data_received(chunk)
 
     def finish(self):
-        self.request.body = b''.join(self._chunks)
-        self.request._parse_body()
-        self.server.request_callback(self.request)
+        if self.delegate is None:
+            self.request.body = b''.join(self._chunks)
+            self.request._parse_body()
+            self.server.request_callback(self.request)
+        else:
+            self.delegate.finish()
+        if self.server.xheaders:
+            self._unapply_xheaders()
 
 
 HTTPRequest = httputil.HTTPServerRequest
index c8161669ae4b6c48d34ff17afd39059e025e50b2..68a0d429f1bbae690d82a3ff6815d7bc669a71e3 100644 (file)
@@ -317,9 +317,11 @@ class HTTPServerRequest(object):
        are typically kept open in HTTP/1.1, multiple requests can be handled
        sequentially on a single connection.
     """
-    def __init__(self, method, uri, version="HTTP/1.0", headers=None,
+    def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
                  body=None, remote_ip=None, protocol=None, host=None,
-                 files=None, connection=None):
+                 files=None, connection=None, start_line=None):
+        if start_line is not None:
+            method, uri, version = start_line
         self.method = method
         self.uri = uri
         self.version = version
@@ -327,8 +329,8 @@ class HTTPServerRequest(object):
         self.body = body or ""
 
         # set remote IP and protocol
-        self.remote_ip = remote_ip
-        self.protocol = protocol or "http"
+        self.remote_ip = remote_ip or getattr(connection, 'remote_ip')
+        self.protocol = protocol or getattr(connection, 'protocol', "http")
 
         self.host = host or self.headers.get("Host") or "127.0.0.1"
         self.files = files or {}
index 424e3544922185d7edaa70300e6ab7e1573b1842..8442cd18989d21a01f637850dc2a6c7539f2d380 100644 (file)
@@ -1423,7 +1423,7 @@ def addslash(method):
     return wrapper
 
 
-class Application(object):
+class Application(httputil.HTTPServerConnectionDelegate):
     """A collection of request handlers that make up a web application.
 
     Instances of this class are callable and can be passed directly to
@@ -1616,64 +1616,15 @@ class Application(object):
                 except TypeError:
                     pass
 
-    def __call__(self, request):
-        """Called by HTTPServer to execute the request."""
-        transforms = [t(request) for t in self.transforms]
-        handler = None
-        args = []
-        kwargs = {}
-        handlers = self._get_host_handlers(request)
-        if not handlers:
-            handler = RedirectHandler(
-                self, request, url="http://" + self.default_host + "/")
-        else:
-            for spec in handlers:
-                match = spec.regex.match(request.path)
-                if match:
-                    handler = spec.handler_class(self, request, **spec.kwargs)
-                    if spec.regex.groups:
-                        # None-safe wrapper around url_unescape to handle
-                        # unmatched optional groups correctly
-                        def unquote(s):
-                            if s is None:
-                                return s
-                            return escape.url_unescape(s, encoding=None,
-                                                       plus=False)
-                        # Pass matched groups to the handler.  Since
-                        # match.groups() includes both named and unnamed groups,
-                        # we want to use either groups or groupdict but not both.
-                        # Note that args are passed as bytes so the handler can
-                        # decide what encoding to use.
-
-                        if spec.regex.groupindex:
-                            kwargs = dict(
-                                (str(k), unquote(v))
-                                for (k, v) in match.groupdict().items())
-                        else:
-                            args = [unquote(s) for s in match.groups()]
-                    break
-            if not handler:
-                if self.settings.get('default_handler_class'):
-                    handler_class = self.settings['default_handler_class']
-                    handler_args = self.settings.get(
-                        'default_handler_args', {})
-                else:
-                    handler_class = ErrorHandler
-                    handler_args = dict(status_code=404)
-                handler = handler_class(self, request, **handler_args)
+    def start_request(self, connection):
+        # Modern HTTPServer interface
+        return _RequestDispatcher(self, connection)
 
-        # If template cache is disabled (usually in the debug mode),
-        # re-compile templates and reload static files on every
-        # request so you don't need to restart to see changes
-        if not self.settings.get("compiled_template_cache", True):
-            with RequestHandler._template_loader_lock:
-                for loader in RequestHandler._template_loaders.values():
-                    loader.reset()
-        if not self.settings.get('static_hash_cache', True):
-            StaticFileHandler.reset()
-
-        handler._execute(transforms, *args, **kwargs)
-        return handler
+    def __call__(self, request):
+        # Legacy HTTPServer interface
+        dispatcher = _RequestDispatcher(self, None)
+        dispatcher.set_request(request)
+        return dispatcher.execute()
 
     def reverse_url(self, name, *args):
         """Returns a URL path for handler named ``name``
@@ -1710,6 +1661,87 @@ class Application(object):
                    handler._request_summary(), request_time)
 
 
+class _RequestDispatcher(httputil.HTTPMessageDelegate):
+    def __init__(self, application, connection):
+        self.application = application
+        self.connection = connection
+        self.request = None
+        self.chunks = []
+        self.handler_class = None
+        self.handler_kwargs = None
+        self.path_args = []
+        self.path_kwargs = {}
+
+    def headers_received(self, start_line, headers):
+        self.set_request(httputil.HTTPServerRequest(
+            connection=self.connection, start_line=start_line, headers=headers))
+
+    def set_request(self, request):
+        self.request = request
+        self._find_handler()
+
+    def _find_handler(self):
+        # Identify the handler to use as soon as we have the request.
+        # Save url path arguments for later.
+        app = self.application
+        handlers = app._get_host_handlers(self.request)
+        if not handlers:
+            self.handler_class = RedirectHandler
+            self.handler_kwargs = dict(url="http://" + app.default_host + "/")
+            return
+        for spec in handlers:
+            match = spec.regex.match(self.request.path)
+            if match:
+                self.handler_class = spec.handler_class
+                self.handler_kwargs = spec.kwargs
+                if spec.regex.groups:
+                    # Pass matched groups to the handler.  Since
+                    # match.groups() includes both named and
+                    # unnamed groups, we want to use either groups
+                    # or groupdict but not both.
+                    if spec.regex.groupindex:
+                        self.path_kwargs = dict(
+                            (str(k), _unquote_or_none(v))
+                            for (k, v) in match.groupdict().items())
+                    else:
+                        self.path_args = [_unquote_or_none(s)
+                                          for s in match.groups()]
+                return
+        if app.settings.get('default_handler_class'):
+            self.handler_class = app.settings['default_handler_class']
+            self.handler_kwargs = app.settings.get(
+                'default_handler_args', {})
+        else:
+            self.handler_class = ErrorHandler
+            self.handler_kwargs = dict(status_code=404)
+
+    def data_received(self, data):
+        self.chunks.append(data)
+
+    def finish(self):
+        self.request.body = b''.join(self.chunks)
+        self.request._parse_body()
+        self.execute()
+
+    def execute(self):
+        # If template cache is disabled (usually in the debug mode),
+        # re-compile templates and reload static files on every
+        # request so you don't need to restart to see changes
+        if not self.application.settings.get("compiled_template_cache", True):
+            with RequestHandler._template_loader_lock:
+                for loader in RequestHandler._template_loaders.values():
+                    loader.reset()
+        if not self.application.settings.get('static_hash_cache', True):
+            StaticFileHandler.reset()
+
+        handler = self.handler_class(self.application, self.request,
+                                     **self.handler_kwargs)
+        transforms = [t(self.request) for t in self.application.transforms]
+        handler._execute(transforms, *self.path_args, **self.path_kwargs)
+        return handler
+
+
+
 class HTTPError(Exception):
     """An exception that will turn into an HTTP error response.
 
@@ -2639,3 +2671,14 @@ def _create_signature(secret, *parts):
     for part in parts:
         hash.update(utf8(part))
     return utf8(hash.hexdigest())
+
+def _unquote_or_none(s):
+    """None-safe wrapper around url_unescape to handle unmatched optional
+    groups correctly.
+
+    Note that args are passed as bytes so the handler can decide what
+    encoding to use.
+    """
+    if s is None:
+        return s
+    return escape.url_unescape(s, encoding=None, plus=False)