]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Refactor HTTP1Connection to use coroutines.
authorBen Darnell <ben@bendarnell.com>
Sun, 23 Feb 2014 21:39:29 +0000 (16:39 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 23 Feb 2014 21:39:29 +0000 (16:39 -0500)
tornado/http1connection.py

index 5c931dd48a1852ca01e2b67cce8d36980f0b66d9..8bf2cc1225b231c367b222a6d25e7deeb126f819 100644 (file)
@@ -18,7 +18,9 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import socket
 
+from tornado.concurrent import Future
 from tornado.escape import native_str
+from tornado import gen
 from tornado import httputil
 from tornado import iostream
 from tornado.log import gen_log
@@ -55,11 +57,36 @@ class HTTP1Connection(object):
         else:
             self.protocol = "http"
         self._clear_request_state()
-        # Save stack context here, outside of any request.  This keeps
-        # contexts from one request from leaking into the next.
-        self._header_callback = stack_context.wrap(self._on_headers)
         self.stream.set_close_callback(self._on_connection_close)
-        self.stream.read_until(b"\r\n\r\n", self._header_callback)
+        self._finish_future = None
+        # Register the future on the IOLoop so its errors get logged.
+        stream.io_loop.add_future(self._process_requests(),
+                                  lambda f: f.result())
+
+    @gen.coroutine
+    def _process_requests(self):
+        while True:
+            try:
+                header_data = yield self.stream.read_until(b"\r\n\r\n")
+                self._finish_future = Future()
+                start_line, headers = self._parse_headers(header_data)
+                request = self._make_request(start_line, headers)
+                self._request = request
+                body_future = self._read_body(headers)
+                if body_future is not None:
+                    request.body = yield body_future
+                self._parse_body(request)
+                self.request_callback(request)
+                yield self._finish_future
+            except _BadRequestException as e:
+                gen_log.info("Malformed HTTP request from %r: %s",
+                             self.address, e)
+                self.close()
+                return
+            except iostream.StreamClosedError:
+                self.close()
+                return
+
 
     def _clear_request_state(self):
         """Clears the per-request state.
@@ -89,15 +116,15 @@ class HTTP1Connection(object):
             callback = self._close_callback
             self._close_callback = None
             callback()
+        if self._finish_future is not None and not self._finish_future.done():
+            self._finish_future.set_result(None)
         # Delete any unfinished callbacks to break up reference cycles.
-        self._header_callback = None
         self._clear_request_state()
 
     def close(self):
         self.stream.close()
         # Remove this reference to self, which would otherwise cause a
         # cycle and delay garbage collection of this connection.
-        self._header_callback = None
         self._clear_request_state()
 
     def write(self, chunk, callback=None):
@@ -148,86 +175,72 @@ class HTTP1Connection(object):
         if disconnect:
             self.close()
             return
+        # Turn Nagle's algorithm back on, leaving the stream in its
+        # default state for the next request.
+        self.stream.set_nodelay(False)
+        self._finish_future.set_result(None)
+
+    def _parse_headers(self, data):
+        data = native_str(data.decode('latin1'))
+        eol = data.find("\r\n")
+        start_line = data[:eol]
         try:
-            # Use a try/except instead of checking stream.closed()
-            # directly, because in some cases the stream doesn't discover
-            # that it's closed until you try to read from it.
-            self.stream.read_until(b"\r\n\r\n", self._header_callback)
-
-            # Turn Nagle's algorithm back on, leaving the stream in its
-            # default state for the next request.
-            self.stream.set_nodelay(False)
-        except iostream.StreamClosedError:
-            self.close()
+            headers = httputil.HTTPHeaders.parse(data[eol:])
+        except ValueError:
+            # probably form split() if there was no ':' in the line
+            raise _BadRequestException("Malformed HTTP headers")
+        return start_line, headers
 
-    def _on_headers(self, data):
+    def _make_request(self, start_line, headers):
         try:
-            data = native_str(data.decode('latin1'))
-            eol = data.find("\r\n")
-            start_line = data[:eol]
-            try:
-                method, uri, version = start_line.split(" ")
-            except ValueError:
-                raise _BadRequestException("Malformed HTTP request line")
-            if not version.startswith("HTTP/"):
-                raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
-            try:
-                headers = httputil.HTTPHeaders.parse(data[eol:])
-            except ValueError:
-                # Probably from split() if there was no ':' in the line
-                raise _BadRequestException("Malformed HTTP headers")
-
-            # HTTPRequest wants an IP, not a full socket address
-            if self.address_family in (socket.AF_INET, socket.AF_INET6):
-                remote_ip = self.address[0]
-            else:
-                # Unix (or other) socket; fake the remote address
-                remote_ip = '0.0.0.0'
-
-            protocol = self.protocol
-
-            # xheaders can override the defaults
-            if self.xheaders:
-                # Squid uses X-Forwarded-For, others use X-Real-Ip
-                ip = headers.get("X-Forwarded-For", remote_ip)
-                ip = ip.split(',')[-1].strip()
-                ip = headers.get("X-Real-Ip", ip)
-                if netutil.is_valid_ip(ip):
-                    remote_ip = ip
-                # AWS uses X-Forwarded-Proto
-                proto_header = headers.get(
-                    "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
-                if proto_header in ("http", "https"):
-                    protocol = proto_header
-
-            self._request = httputil.HTTPServerRequest(
-                connection=self, method=method, uri=uri, version=version,
-                headers=headers, remote_ip=remote_ip, protocol=protocol)
-
-            content_length = headers.get("Content-Length")
-            if content_length:
-                content_length = int(content_length)
-                if content_length > self.stream.max_buffer_size:
-                    raise _BadRequestException("Content-Length too long")
-                if headers.get("Expect") == "100-continue":
-                    self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
-                self.stream.read_bytes(content_length, self._on_request_body)
-                return
-
-            self.request_callback(self._request)
-        except _BadRequestException as e:
-            gen_log.info("Malformed HTTP request from %r: %s",
-                         self.address, e)
-            self.close()
-            return
-
-    def _on_request_body(self, data):
-        self._request.body = data
+            method, uri, version = start_line.split(" ")
+        except ValueError:
+            raise _BadRequestException("Malformed HTTP request line")
+        if not version.startswith("HTTP/"):
+            raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
+        # HTTPRequest wants an IP, not a full socket address
+        if self.address_family in (socket.AF_INET, socket.AF_INET6):
+            remote_ip = self.address[0]
+        else:
+            # Unix (or other) socket; fake the remote address
+            remote_ip = '0.0.0.0'
+
+        protocol = self.protocol
+
+        # xheaders can override the defaults
+        if self.xheaders:
+            # Squid uses X-Forwarded-For, others use X-Real-Ip
+            ip = headers.get("X-Forwarded-For", remote_ip)
+            ip = ip.split(',')[-1].strip()
+            ip = headers.get("X-Real-Ip", ip)
+            if netutil.is_valid_ip(ip):
+                remote_ip = ip
+            # AWS uses X-Forwarded-Proto
+            proto_header = headers.get(
+                "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
+            if proto_header in ("http", "https"):
+                protocol = proto_header
+
+        return httputil.HTTPServerRequest(
+            connection=self, method=method, uri=uri, version=version,
+            headers=headers, remote_ip=remote_ip, protocol=protocol)
+
+    def _read_body(self, headers):
+        content_length = headers.get("Content-Length")
+        if content_length:
+            content_length = int(content_length)
+            if content_length > self.stream.max_buffer_size:
+                raise _BadRequestException("Content-Length too long")
+            if headers.get("Expect") == "100-continue":
+                self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
+            return self.stream.read_bytes(content_length)
+        return None
+
+    def _parse_body(self, request):
         if self._request.method in ("POST", "PATCH", "PUT"):
             httputil.parse_body_arguments(
-                self._request.headers.get("Content-Type", ""), data,
+                self._request.headers.get("Content-Type", ""), request.body,
                 self._request.body_arguments, self._request.files)
 
             for k, v in self._request.body_arguments.items():
                 self._request.arguments.setdefault(k, []).extend(v)
-        self.request_callback(self._request)