]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Move HTTPServerRequest-specific logic from http1connection to httpserver.
authorBen Darnell <ben@bendarnell.com>
Mon, 24 Feb 2014 05:53:53 +0000 (00:53 -0500)
committerBen Darnell <ben@bendarnell.com>
Mon, 24 Feb 2014 05:53:53 +0000 (00:53 -0500)
tornado/http1connection.py
tornado/httpserver.py
tornado/httputil.py

index 8bf2cc1225b231c367b222a6d25e7deeb126f819..b491de6c476da7a3fe8b232709591cf263f4b26c 100644 (file)
 
 from __future__ import absolute_import, division, print_function, with_statement
 
-import socket
-
 from tornado.concurrent import Future
 from tornado.escape import native_str
 from tornado import gen
 from tornado import httputil
 from tornado import iostream
 from tornado.log import gen_log
-from tornado import netutil
 from tornado import stack_context
 
 
-class _BadRequestException(Exception):
-    """Exception class for malformed HTTP requests."""
-    pass
-
-
 class HTTP1Connection(object):
     """Handles a connection to an HTTP client, executing HTTP requests.
 
     We parse HTTP headers and bodies, and execute the request callback
     until the HTTP conection is closed.
     """
-    def __init__(self, stream, address, request_callback, no_keep_alive=False,
-                 xheaders=False, protocol=None):
+    def __init__(self, stream, address, delegate, no_keep_alive=False,
+                 protocol=None):
         self.stream = stream
         self.address = address
         # Save the socket's address family now so we know how to
         # interpret self.address even after the stream is closed
         # and its socket attribute replaced with None.
         self.address_family = stream.socket.family
-        self.request_callback = request_callback
+        self.delegate = delegate
         self.no_keep_alive = no_keep_alive
-        self.xheaders = xheaders
         if protocol:
             self.protocol = protocol
         elif isinstance(stream, iostream.SSLIOStream):
             self.protocol = "https"
         else:
             self.protocol = "http"
+        self._disconnect_on_finish = False
         self._clear_request_state()
         self.stream.set_close_callback(self._on_connection_close)
         self._finish_future = None
@@ -68,17 +60,18 @@ class HTTP1Connection(object):
         while True:
             try:
                 header_data = yield self.stream.read_until(b"\r\n\r\n")
+                request_delegate = self.delegate.start_request(self)
                 self._finish_future = Future()
                 start_line, headers = self._parse_headers(header_data)
-                request = self._make_request(start_line, headers)
-                self._request = request
+                self._disconnect_on_finish = not self._can_keep_alive(
+                    start_line, headers)
+                request_delegate.headers_received(start_line, headers)
                 body_future = self._read_body(headers)
                 if body_future is not None:
-                    request.body = yield body_future
-                self._parse_body(request)
-                self.request_callback(request)
+                    request_delegate.data_received((yield body_future))
+                request_delegate.finish()
                 yield self._finish_future
-            except _BadRequestException as e:
+            except httputil.BadRequestException as e:
                 gen_log.info("Malformed HTTP request from %r: %s",
                              self.address, e)
                 self.close()
@@ -96,7 +89,6 @@ class HTTP1Connection(object):
         and when the connection is closed (to break up cycles and
         facilitate garbage collection in cpython).
         """
-        self._request = None
         self._request_finished = False
         self._write_callback = None
         self._close_callback = None
@@ -157,22 +149,22 @@ class HTTP1Connection(object):
         if self._request_finished and not self.stream.writing():
             self._finish_request()
 
+    def _can_keep_alive(self, start_line, headers):
+        if self.no_keep_alive:
+            return False
+        connection_header = headers.get("Connection")
+        if connection_header is not None:
+            connection_header = connection_header.lower()
+        if start_line.endswith("HTTP/1.1"):
+            return connection_header != "close"
+        elif ("Content-Length" in headers
+              or start_line.startswith(("HEAD ", "GET "))):
+            return connection_header == "keep-alive"
+        return False
+
     def _finish_request(self):
-        if self.no_keep_alive or self._request is None:
-            disconnect = True
-        else:
-            connection_header = self._request.headers.get("Connection")
-            if connection_header is not None:
-                connection_header = connection_header.lower()
-            if self._request.supports_http_1_1():
-                disconnect = connection_header == "close"
-            elif ("Content-Length" in self._request.headers
-                    or self._request.method in ("HEAD", "GET")):
-                disconnect = connection_header != "keep-alive"
-            else:
-                disconnect = True
         self._clear_request_state()
-        if disconnect:
+        if self._disconnect_on_finish:
             self.close()
             return
         # Turn Nagle's algorithm back on, leaving the stream in its
@@ -188,59 +180,16 @@ class HTTP1Connection(object):
             headers = httputil.HTTPHeaders.parse(data[eol:])
         except ValueError:
             # probably form split() if there was no ':' in the line
-            raise _BadRequestException("Malformed HTTP headers")
+            raise httputil.BadRequestException("Malformed HTTP headers")
         return start_line, headers
 
-    def _make_request(self, start_line, headers):
-        try:
-            method, uri, version = start_line.split(" ")
-        except ValueError:
-            raise _BadRequestException("Malformed HTTP request line")
-        if not version.startswith("HTTP/"):
-            raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
-        # HTTPRequest wants an IP, not a full socket address
-        if self.address_family in (socket.AF_INET, socket.AF_INET6):
-            remote_ip = self.address[0]
-        else:
-            # Unix (or other) socket; fake the remote address
-            remote_ip = '0.0.0.0'
-
-        protocol = self.protocol
-
-        # xheaders can override the defaults
-        if self.xheaders:
-            # Squid uses X-Forwarded-For, others use X-Real-Ip
-            ip = headers.get("X-Forwarded-For", remote_ip)
-            ip = ip.split(',')[-1].strip()
-            ip = headers.get("X-Real-Ip", ip)
-            if netutil.is_valid_ip(ip):
-                remote_ip = ip
-            # AWS uses X-Forwarded-Proto
-            proto_header = headers.get(
-                "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
-            if proto_header in ("http", "https"):
-                protocol = proto_header
-
-        return httputil.HTTPServerRequest(
-            connection=self, method=method, uri=uri, version=version,
-            headers=headers, remote_ip=remote_ip, protocol=protocol)
-
     def _read_body(self, headers):
         content_length = headers.get("Content-Length")
         if content_length:
             content_length = int(content_length)
             if content_length > self.stream.max_buffer_size:
-                raise _BadRequestException("Content-Length too long")
+                raise httputil.BadRequestException("Content-Length too long")
             if headers.get("Expect") == "100-continue":
                 self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
             return self.stream.read_bytes(content_length)
         return None
-
-    def _parse_body(self, request):
-        if self._request.method in ("POST", "PATCH", "PUT"):
-            httputil.parse_body_arguments(
-                self._request.headers.get("Content-Type", ""), request.body,
-                self._request.body_arguments, self._request.files)
-
-            for k, v in self._request.body_arguments.items():
-                self._request.arguments.setdefault(k, []).extend(v)
index e30bc32f5ef1fff6377e8d3a167f43316eb427a7..44a2c94bde477d26d863710fcf2c393e28f94db2 100644 (file)
@@ -28,12 +28,14 @@ class except to start a server at the beginning of the process
 
 from __future__ import absolute_import, division, print_function, with_statement
 
+import socket
 
 from tornado import http1connection, httputil
+from tornado import netutil
 from tornado.tcpserver import TCPServer
 
 
-class HTTPServer(TCPServer):
+class HTTPServer(TCPServer, httputil.HTTPConnectionDelegate):
     r"""A non-blocking, single-threaded HTTP server.
 
     A server is defined by a request callback that takes an HTTPRequest
@@ -141,8 +143,66 @@ class HTTPServer(TCPServer):
                            **kwargs)
 
     def handle_stream(self, stream, address):
-        HTTPConnection(stream, address, self.request_callback,
-                       self.no_keep_alive, self.xheaders, self.protocol)
+        HTTPConnection(stream, address, self, self.no_keep_alive, self.protocol)
+
+    def start_request(self, connection):
+        return _ServerRequestProcessor(self, connection)
+
+class _ServerRequestProcessor(httputil.HTTPStreamDelegate):
+    def __init__(self, server, connection):
+        self.server = server
+        self.connection = connection
+
+    def headers_received(self, start_line, headers):
+        pass
+        try:
+            method, uri, version = start_line.split(" ")
+        except ValueError:
+            raise httputil.BadRequestException("Malformed HTTP request line")
+        if not version.startswith("HTTP/"):
+            raise httputil.BadRequestException("Malformed HTTP version in HTTP Request-Line")
+        # HTTPRequest wants an IP, not a full socket address
+        if self.connection.address_family in (socket.AF_INET, socket.AF_INET6):
+            remote_ip = self.connection.address[0]
+        else:
+            # Unix (or other) socket; fake the remote address
+            remote_ip = '0.0.0.0'
+
+        protocol = self.connection.protocol
+
+        # xheaders can override the defaults
+        if self.server.xheaders:
+            # Squid uses X-Forwarded-For, others use X-Real-Ip
+            ip = headers.get("X-Forwarded-For", remote_ip)
+            ip = ip.split(',')[-1].strip()
+            ip = headers.get("X-Real-Ip", ip)
+            if netutil.is_valid_ip(ip):
+                remote_ip = ip
+            # AWS uses X-Forwarded-Proto
+            proto_header = headers.get(
+                "X-Scheme", headers.get("X-Forwarded-Proto", protocol))
+            if proto_header in ("http", "https"):
+                protocol = proto_header
+
+        self.request = httputil.HTTPServerRequest(
+            connection=self.connection, method=method, uri=uri, version=version,
+            headers=headers, remote_ip=remote_ip, protocol=protocol)
+
+    def data_received(self, chunk):
+        assert not self.request.body
+        self.request.body = chunk
+
+    def finish(self):
+        if self.request.method in ("POST", "PATCH", "PUT"):
+            httputil.parse_body_arguments(
+                self.request.headers.get("Content-Type", ""), self.request.body,
+                self.request.body_arguments, self.request.files)
+
+            for k, v in self.request.body_arguments.items():
+                self.request.arguments.setdefault(k, []).extend(v)
+
+        self.server.request_callback(self.request)
+
 
 
 HTTPRequest = httputil.HTTPServerRequest
index fac21ec0869dd664a60dd51f7b5b5f7909217553..8164d2eb84fae414a72a04ab8c63ed9886c1d5c9 100644 (file)
@@ -411,6 +411,27 @@ class HTTPServerRequest(object):
             self.__class__.__name__, args, dict(self.headers))
 
 
+class BadRequestException(Exception):
+    """Exception class for malformed HTTP requests."""
+    pass
+
+
+class HTTPConnectionDelegate(object):
+    def start_request(self, connection):
+        raise NotImplementedError()
+
+
+class HTTPStreamDelegate(object):
+    def headers_received(self, start_line, headers):
+        pass
+
+    def data_received(self, chunk):
+        pass
+
+    def finish(self):
+        pass
+
+
 def url_concat(url, args):
     """Concatenate url and argument dictionary regardless of whether
     url has existing query parameters.