From 5c283f9c6bf6c91d65be65a09dbada5a0116e2ad Mon Sep 17 00:00:00 2001
From: Ben Darnell
Date: Sun, 23 Feb 2014 16:39:29 -0500
Subject: [PATCH] Refactor HTTP1Connection to use coroutines.

---
 tornado/http1connection.py | 173 ++++++++++++++++++++-----------------
 1 file changed, 93 insertions(+), 80 deletions(-)

diff --git a/tornado/http1connection.py b/tornado/http1connection.py
index 5c931dd48..8bf2cc122 100644
--- a/tornado/http1connection.py
+++ b/tornado/http1connection.py
@@ -18,7 +18,9 @@ from __future__ import absolute_import, division, print_function, with_statement
 
 import socket
 
+from tornado.concurrent import Future
 from tornado.escape import native_str
+from tornado import gen
 from tornado import httputil
 from tornado import iostream
 from tornado.log import gen_log
@@ -55,11 +57,36 @@ class HTTP1Connection(object):
         else:
             self.protocol = "http"
         self._clear_request_state()
-        # Save stack context here, outside of any request. This keeps
-        # contexts from one request from leaking into the next.
-        self._header_callback = stack_context.wrap(self._on_headers)
         self.stream.set_close_callback(self._on_connection_close)
-        self.stream.read_until(b"\r\n\r\n", self._header_callback)
+        self._finish_future = None
+        # Register the future on the IOLoop so its errors get logged.
+        stream.io_loop.add_future(self._process_requests(),
+                                  lambda f: f.result())
+
+    @gen.coroutine
+    def _process_requests(self):
+        while True:
+            try:
+                header_data = yield self.stream.read_until(b"\r\n\r\n")
+                self._finish_future = Future()
+                start_line, headers = self._parse_headers(header_data)
+                request = self._make_request(start_line, headers)
+                self._request = request
+                body_future = self._read_body(headers)
+                if body_future is not None:
+                    request.body = yield body_future
+                self._parse_body(request)
+                self.request_callback(request)
+                yield self._finish_future
+            except _BadRequestException as e:
+                gen_log.info("Malformed HTTP request from %r: %s",
+                             self.address, e)
+                self.close()
+                return
+            except iostream.StreamClosedError:
+                self.close()
+                return
+
 
     def _clear_request_state(self):
         """Clears the per-request state.
@@ -89,15 +116,15 @@
             callback = self._close_callback
             self._close_callback = None
             callback()
+        if self._finish_future is not None and not self._finish_future.done():
+            self._finish_future.set_result(None)
         # Delete any unfinished callbacks to break up reference cycles.
-        self._header_callback = None
         self._clear_request_state()
 
     def close(self):
         self.stream.close()
         # Remove this reference to self, which would otherwise cause a
         # cycle and delay garbage collection of this connection.
-        self._header_callback = None
         self._clear_request_state()
 
     def write(self, chunk, callback=None):
@@ -148,86 +175,72 @@
         if disconnect:
             self.close()
             return
+        # Turn Nagle's algorithm back on, leaving the stream in its
+        # default state for the next request.
+        self.stream.set_nodelay(False)
+        self._finish_future.set_result(None)
+
+    def _parse_headers(self, data):
+        data = native_str(data.decode('latin1'))
+        eol = data.find("\r\n")
+        start_line = data[:eol]
         try:
-            # Use a try/except instead of checking stream.closed()
-            # directly, because in some cases the stream doesn't discover
-            # that it's closed until you try to read from it.
-            self.stream.read_until(b"\r\n\r\n", self._header_callback)
-
-            # Turn Nagle's algorithm back on, leaving the stream in its
-            # default state for the next request.
-            self.stream.set_nodelay(False)
-        except iostream.StreamClosedError:
-            self.close()
+            headers = httputil.HTTPHeaders.parse(data[eol:])
+        except ValueError:
+            # probably from split() if there was no ':' in the line
+            raise _BadRequestException("Malformed HTTP headers")
+        return start_line, headers
 
-    def _on_headers(self, data):
+    def _make_request(self, start_line, headers):
         try:
-            data = native_str(data.decode('latin1'))
-            eol = data.find("\r\n")
-            start_line = data[:eol]
-            try:
-                method, uri, version = start_line.split(" ")
-            except ValueError:
-                raise _BadRequestException("Malformed HTTP request line")
-            if not version.startswith("HTTP/"):
-                raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
-            try:
-                headers = httputil.HTTPHeaders.parse(data[eol:])
-            except ValueError:
-                # Probably from split() if there was no ':' in the line
-                raise _BadRequestException("Malformed HTTP headers")
-
-            # HTTPRequest wants an IP, not a full socket address
-            if self.address_family in (socket.AF_INET, socket.AF_INET6):
-                remote_ip = self.address[0]
-            else:
-                # Unix (or other) socket; fake the remote address
-                remote_ip = '0.0.0.0'
-
-            protocol = self.protocol
-
-            # xheaders can override the defaults
-            if self.xheaders:
-                # Squid uses X-Forwarded-For, others use X-Real-Ip
-                ip = headers.get("X-Forwarded-For", remote_ip)
-                ip = ip.split(',')[-1].strip()
-                ip = headers.get("X-Real-Ip", ip)
-                if netutil.is_valid_ip(ip):
-                    remote_ip = ip
-                # AWS uses X-Forwarded-Proto
-                proto_header = headers.get(
-                    "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
-                if proto_header in ("http", "https"):
-                    protocol = proto_header
-
-            self._request = httputil.HTTPServerRequest(
-                connection=self, method=method, uri=uri, version=version,
-                headers=headers, remote_ip=remote_ip, protocol=protocol)
-
-            content_length = headers.get("Content-Length")
-            if content_length:
-                content_length = int(content_length)
-                if content_length > self.stream.max_buffer_size:
-                    raise _BadRequestException("Content-Length too long")
-                if headers.get("Expect") == "100-continue":
-                    self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
-                self.stream.read_bytes(content_length, self._on_request_body)
-                return
-
-            self.request_callback(self._request)
-        except _BadRequestException as e:
-            gen_log.info("Malformed HTTP request from %r: %s",
-                         self.address, e)
-            self.close()
-            return
-
-    def _on_request_body(self, data):
-        self._request.body = data
+            method, uri, version = start_line.split(" ")
+        except ValueError:
+            raise _BadRequestException("Malformed HTTP request line")
+        if not version.startswith("HTTP/"):
+            raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
+        # HTTPRequest wants an IP, not a full socket address
+        if self.address_family in (socket.AF_INET, socket.AF_INET6):
+            remote_ip = self.address[0]
+        else:
+            # Unix (or other) socket; fake the remote address
+            remote_ip = '0.0.0.0'
+
+        protocol = self.protocol
+
+        # xheaders can override the defaults
+        if self.xheaders:
+            # Squid uses X-Forwarded-For, others use X-Real-Ip
+            ip = headers.get("X-Forwarded-For", remote_ip)
+            ip = ip.split(',')[-1].strip()
+            ip = headers.get("X-Real-Ip", ip)
+            if netutil.is_valid_ip(ip):
+                remote_ip = ip
+            # AWS uses X-Forwarded-Proto
+            proto_header = headers.get(
+                "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
+            if proto_header in ("http", "https"):
+                protocol = proto_header
+
+        return httputil.HTTPServerRequest(
+            connection=self, method=method, uri=uri, version=version,
+            headers=headers, remote_ip=remote_ip, protocol=protocol)
+
+    def _read_body(self, headers):
+        content_length = headers.get("Content-Length")
+        if content_length:
+            content_length = int(content_length)
+            if content_length > self.stream.max_buffer_size:
+                raise _BadRequestException("Content-Length too long")
+            if headers.get("Expect") == "100-continue":
+                self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
+            return self.stream.read_bytes(content_length)
+        return None
+
+    def _parse_body(self, request):
         if self._request.method in ("POST", "PATCH", "PUT"):
             httputil.parse_body_arguments(
-                self._request.headers.get("Content-Type", ""), data,
+                self._request.headers.get("Content-Type", ""), request.body,
                 self._request.body_arguments, self._request.files)
 
             for k, v in self._request.body_arguments.items():
                 self._request.arguments.setdefault(k, []).extend(v)
-        self.request_callback(self._request)
-- 
2.47.2
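
Editor's note on the core pattern: the patch replaces the old callback chain with a single per-connection _process_requests coroutine that handles one request, parks on self._finish_future until the response has been finished, and only then reads the next request off the stream. The following is a minimal, self-contained sketch of that handshake, not part of the patch itself; DemoConnection, handler, and the canned request strings are invented here purely for illustration, and it only assumes tornado.gen, tornado.concurrent.Future, and IOLoop.run_sync as they exist in Tornado 3.x and later.

# Illustrative sketch only: mimics the "finish future" handshake used by
# HTTP1Connection._process_requests(), without sockets or HTTP parsing.

from tornado import gen
from tornado.concurrent import Future
from tornado.ioloop import IOLoop


class DemoConnection(object):
    def __init__(self, requests, request_callback):
        self._requests = list(requests)   # stands in for data read off the stream
        self._request_callback = request_callback
        self._finish_future = None

    @gen.coroutine
    def process_requests(self):
        while self._requests:
            request = self._requests.pop(0)
            # A fresh Future per request, as in the patch.
            self._finish_future = Future()
            self._request_callback(self, request)
            # Park here until finish() is called, so the next request is not
            # started before the current one has been answered.
            yield self._finish_future

    def finish(self):
        self._finish_future.set_result(None)


def handler(connection, request):
    print("handled %s" % request)
    connection.finish()


if __name__ == "__main__":
    conn = DemoConnection(["GET /a HTTP/1.1", "GET /b HTTP/1.1"], handler)
    IOLoop.current().run_sync(conn.process_requests)

If a handler never calls finish(), the coroutine simply stays parked on the Future, which mirrors how the real connection avoids reading a pipelined request before the current response is complete.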