]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Move address manipulation from HTTP1Connection to HTTPServer.
authorBen Darnell <ben@bendarnell.com>
Sun, 20 Apr 2014 21:01:43 +0000 (17:01 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 21 Apr 2014 02:52:52 +0000 (22:52 -0400)
The connection now contains only an opaque 'origin' object to which
the caller can attach address info.  This object can also be mutated
as in HTTPServer's xheader support.

tornado/http1connection.py
tornado/httpserver.py
tornado/httputil.py
tornado/simple_httpclient.py
tornado/test/httpserver_test.py
tornado/wsgi.py

index 68e93143340fac93c2cd011e3f62fd82fe060fe4..b06680031084613c46490513ebea3f630f718ac9 100644 (file)
@@ -16,8 +16,6 @@
 
 from __future__ import absolute_import, division, print_function, with_statement
 
-import socket
-
 from tornado.concurrent import Future
 from tornado.escape import native_str, utf8
 from tornado import gen
@@ -46,34 +44,14 @@ class HTTP1Connection(object):
     We parse HTTP headers and bodies, and execute the request callback
     until the HTTP conection is closed.
     """
-    def __init__(self, stream, address, is_client, params=None):
+    def __init__(self, stream, is_client, params=None, context=None):
         self.is_client = is_client
         self.stream = stream
-        self.address = address
         if params is None:
             params = HTTP1ConnectionParameters()
         self.params = params
+        self.context = context
         self.no_keep_alive = params.no_keep_alive
-        # Save the socket's address family now so we know how to
-        # interpret self.address even after the stream is closed
-        # and its socket attribute replaced with None.
-        if stream.socket is not None:
-            self.address_family = stream.socket.family
-        else:
-            self.address_family = None
-        # In HTTPServerRequest we want an IP, not a full socket address.
-        if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
-            address is not None):
-            self.remote_ip = address[0]
-        else:
-            # Unix (or other) socket; fake the remote address.
-            self.remote_ip = '0.0.0.0'
-        if self.params.protocol:
-            self.protocol = self.params.protocol
-        elif isinstance(stream, iostream.SSLIOStream):
-            self.protocol = "https"
-        else:
-            self.protocol = "http"
         # The body limits can be altered by the delegate, so save them
         # here instead of just referencing self.params later.
         self._max_body_size = (self.params.max_body_size or
@@ -169,8 +147,8 @@ class HTTP1Connection(object):
                                 self.stream.io_loop.time() + self._body_timeout,
                                 body_future, self.stream.io_loop)
                         except gen.TimeoutError:
-                            gen_log.info("Timeout reading body from %r",
-                                         self.address)
+                            gen_log.info("Timeout reading body from %s",
+                                         self.context)
                             self.stream.close()
                             raise gen.Return(False)
             self._read_finished = True
@@ -182,8 +160,8 @@ class HTTP1Connection(object):
             if self.stream is None:
                 raise gen.Return(False)
         except httputil.HTTPInputException as e:
-            gen_log.info("Malformed HTTP message from %r: %s",
-                         self.address, e)
+            gen_log.info("Malformed HTTP message from %s: %s",
+                         self.context, e)
             self.close()
             raise gen.Return(False)
         finally:
@@ -481,12 +459,12 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
 
 
 class HTTP1ServerConnection(object):
-    def __init__(self, stream, address, params=None):
+    def __init__(self, stream, params=None, context=None):
         self.stream = stream
-        self.address = address
         if params is None:
             params = HTTP1ConnectionParameters()
         self.params = params
+        self.context = context
 
     def start_serving(self, delegate):
         assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
@@ -498,8 +476,8 @@ class HTTP1ServerConnection(object):
     @gen.coroutine
     def _server_request_loop(self, delegate):
         while True:
-            conn = HTTP1Connection(self.stream, self.address, False,
-                                   self.params)
+            conn = HTTP1Connection(self.stream, False,
+                                   self.params, self.context)
             request_delegate = delegate.start_request(conn)
             try:
                 ret = yield conn.read_response(request_delegate)
index 036c1580034c97f1c00c11bbd4b140b68ebcbb25..2021ae4d8192fa1447bf0e20d03f97a21f52b2d9 100644 (file)
@@ -28,8 +28,11 @@ class except to start a server at the beginning of the process
 
 from __future__ import absolute_import, division, print_function, with_statement
 
+import socket
+
 from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
 from tornado import httputil
+from tornado import iostream
 from tornado import netutil
 from tornado.tcpserver import TCPServer
 
@@ -153,51 +156,64 @@ class HTTPServer(TCPServer, httputil.HTTPServerConnectionDelegate):
                            read_chunk_size=chunk_size)
 
     def handle_stream(self, stream, address):
+        context = _HTTPRequestContext(stream, address,
+                                      self.conn_params.protocol)
         conn = HTTP1ServerConnection(
-            stream, address=address,
-            params=self.conn_params)
+            stream, self.conn_params, context)
         conn.start_serving(self)
 
     def start_request(self, connection):
         return _ServerRequestAdapter(self, connection)
 
 
-class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
-    """Adapts the `HTTPMessageDelegate` interface to the interface expected
-    by our clients.
-    """
-    def __init__(self, server, connection):
-        self.server = server
-        self.connection = connection
-        self.request = None
-        if isinstance(server.request_callback,
-                      httputil.HTTPServerConnectionDelegate):
-            self.delegate = server.request_callback.start_request(connection)
-            self._chunks = None
+class _HTTPRequestContext(object):
+    def __init__(self, stream, address, protocol):
+        self.address = address
+        self.protocol = protocol
+        # Save the socket's address family now so we know how to
+        # interpret self.address even after the stream is closed
+        # and its socket attribute replaced with None.
+        if stream.socket is not None:
+            self.address_family = stream.socket.family
         else:
-            self.delegate = None
-            self._chunks = []
+            self.address_family = None
+        # In HTTPServerRequest we want an IP, not a full socket address.
+        if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
+            address is not None):
+            self.remote_ip = address[0]
+        else:
+            # Unix (or other) socket; fake the remote address.
+            self.remote_ip = '0.0.0.0'
+        if protocol:
+            self.protocol = protocol
+        elif isinstance(stream, iostream.SSLIOStream):
+            self.protocol = "https"
+        else:
+            self.protocol = "http"
+        self._orig_remote_ip = self.remote_ip
+        self._orig_protocol = self.protocol
 
-    def _apply_xheaders(self, headers):
-        """Rewrite the connection.remote_ip and connection.protocol fields.
 
-        This is hacky, but the movement of logic between `HTTPServer`
-        and `.Application` leaves us without a clean place to do this.
-        """
-        self._orig_remote_ip = self.connection.remote_ip
-        self._orig_protocol = self.connection.protocol
+    def __str__(self):
+        if self.address_family in (socket.AF_INET, socket.AF_INET6):
+            return self.remote_ip
+        else:
+            return str(self.address)
+
+    def _apply_xheaders(self, headers):
+        """Rewrite the ``remote_ip`` and ``protocol`` fields."""
         # Squid uses X-Forwarded-For, others use X-Real-Ip
-        ip = headers.get("X-Forwarded-For", self.connection.remote_ip)
+        ip = headers.get("X-Forwarded-For", self.remote_ip)
         ip = ip.split(',')[-1].strip()
         ip = headers.get("X-Real-Ip", ip)
         if netutil.is_valid_ip(ip):
-            self.connection.remote_ip = ip
+            self.remote_ip = ip
         # AWS uses X-Forwarded-Proto
         proto_header = headers.get(
             "X-Scheme", headers.get("X-Forwarded-Proto",
-                                    self.connection.protocol))
+                                    self.protocol))
         if proto_header in ("http", "https"):
-            self.connection.protocol = proto_header
+            self.protocol = proto_header
 
     def _unapply_xheaders(self):
         """Undo changes from `_apply_xheaders`.
@@ -205,12 +221,29 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
         Xheaders are per-request so they should not leak to the next
         request on the same connection.
         """
-        self.connection.remote_ip = self._orig_remote_ip
-        self.connection.protocol = self._orig_protocol
+        self.remote_ip = self._orig_remote_ip
+        self.protocol = self._orig_protocol
+
+
+class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
+    """Adapts the `HTTPMessageDelegate` interface to the interface expected
+    by our clients.
+    """
+    def __init__(self, server, connection):
+        self.server = server
+        self.connection = connection
+        self.request = None
+        if isinstance(server.request_callback,
+                      httputil.HTTPServerConnectionDelegate):
+            self.delegate = server.request_callback.start_request(connection)
+            self._chunks = None
+        else:
+            self.delegate = None
+            self._chunks = []
 
     def headers_received(self, start_line, headers):
         if self.server.xheaders:
-            self._apply_xheaders(headers)
+            self.connection.context._apply_xheaders(headers)
         if self.delegate is None:
             self.request = httputil.HTTPServerRequest(
                 connection=self.connection, start_line=start_line,
@@ -232,7 +265,7 @@ class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
         else:
             self.delegate.finish()
         if self.server.xheaders:
-            self._unapply_xheaders()
+            self.connection.context._unapply_xheaders()
 
 
 HTTPRequest = httputil.HTTPServerRequest
index 8320e46b13e6cca8b7c7f9d6132c2e5767c1cf2c..b96f360d184b0655d2f8c1ea0c47e879bb40c6ef 100644 (file)
@@ -318,8 +318,8 @@ class HTTPServerRequest(object):
        sequentially on a single connection.
     """
     def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
-                 body=None, remote_ip=None, protocol=None, host=None,
-                 files=None, connection=None, start_line=None):
+                 body=None, host=None, files=None, connection=None,
+                 start_line=None):
         if start_line is not None:
             method, uri, version = start_line
         self.method = method
@@ -329,8 +329,9 @@ class HTTPServerRequest(object):
         self.body = body or ""
 
         # set remote IP and protocol
-        self.remote_ip = remote_ip or getattr(connection, 'remote_ip')
-        self.protocol = protocol or getattr(connection, 'protocol', "http")
+        context = getattr(connection, 'context', None)
+        self.remote_ip = getattr(context, 'remote_ip')
+        self.protocol = getattr(context, 'protocol', "http")
 
         self.host = host or self.headers.get("Host") or "127.0.0.1"
         self.files = files or {}
index 06007d9d224b4a6fb969b3e22acd3d6e5b3e3659..417498eb7361c9ba9f82399dd7fe2b1a6d96163f 100644 (file)
@@ -333,11 +333,12 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                    (('?' + self.parsed.query) if self.parsed.query else ''))
         self.stream.set_nodelay(True)
         self.connection = HTTP1Connection(
-            self.stream, self._sockaddr, True,
+            self.stream, True,
             HTTP1ConnectionParameters(
                 no_keep_alive=True, protocol=self.parsed.scheme,
                 max_header_size=self.max_header_size,
-                use_gzip=self.request.use_gzip))
+                use_gzip=self.request.use_gzip),
+            self._sockaddr)
         start_line = httputil.RequestStartLine(self.request.method,
                                                req_path, 'HTTP/1.1')
         self.connection.write_headers(
index 0e8fb36befe3a01ab4a54734961493681ab82e56..22ef7eccc06f41a3a0c646477d8385bc2ad18759 100644 (file)
@@ -41,7 +41,7 @@ def read_stream_body(stream, callback):
 
         def finish(self):
             callback(b''.join(chunks))
-    conn = HTTP1Connection(stream, None, is_client=True)
+    conn = HTTP1Connection(stream, True)
     conn.read_response(Delegate())
 
 
index 1d17810895b8c729d810a0ec6b5c808f9516cd2c..4d2c4627a154066f0697bd700e2b82ae62eaa0a6 100644 (file)
@@ -85,9 +85,10 @@ class WSGIApplication(web.Application):
 
 
 class _WSGIConnection(object):
-    def __init__(self, method, start_response):
+    def __init__(self, method, start_response, context):
         self.method = method
         self.start_response = start_response
+        self.context = context
         self._write_buffer = []
         self._finished = False
         self._expected_content_remaining = None
@@ -134,6 +135,15 @@ class _WSGIConnection(object):
         self._finished = True
 
 
+class _WSGIRequestContext(object):
+    def __init__(self, remote_ip, protocol):
+        self.remote_ip = remote_ip
+        self.protocol = protocol
+
+    def __str__(self):
+        return self.remote_ip
+
+
 class WSGIAdapter(object):
     """Converts a `tornado.web.Application` instance into a WSGI application.
 
@@ -197,10 +207,10 @@ class WSGIAdapter(object):
             host = environ["HTTP_HOST"]
         else:
             host = environ["SERVER_NAME"]
-        connection = _WSGIConnection(method, start_response)
+        connection = _WSGIConnection(method, start_response,
+                                     _WSGIRequestContext(remote_ip, protocol))
         request = httputil.HTTPServerRequest(
-            method, uri, "HTTP/1.1",
-            headers=headers, body=body, remote_ip=remote_ip, protocol=protocol,
+            method, uri, "HTTP/1.1", headers=headers, body=body,
             host=host, connection=connection)
         request._parse_body()
         self.application(request)