From: Ben Darnell Date: Mon, 24 Feb 2014 05:53:53 +0000 (-0500) Subject: Move HTTPServerRequest-specific logic from http1connection to httpserver. X-Git-Tag: v4.0.0b1~91^2~55 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e564ddbe32bb2b250beef44eca3bdeb03fe4838d;p=thirdparty%2Ftornado.git Move HTTPServerRequest-specific logic from http1connection to httpserver. --- diff --git a/tornado/http1connection.py b/tornado/http1connection.py index 8bf2cc122..b491de6c4 100644 --- a/tornado/http1connection.py +++ b/tornado/http1connection.py @@ -16,46 +16,38 @@ from __future__ import absolute_import, division, print_function, with_statement -import socket - from tornado.concurrent import Future from tornado.escape import native_str from tornado import gen from tornado import httputil from tornado import iostream from tornado.log import gen_log -from tornado import netutil from tornado import stack_context -class _BadRequestException(Exception): - """Exception class for malformed HTTP requests.""" - pass - - class HTTP1Connection(object): """Handles a connection to an HTTP client, executing HTTP requests. We parse HTTP headers and bodies, and execute the request callback until the HTTP conection is closed. """ - def __init__(self, stream, address, request_callback, no_keep_alive=False, - xheaders=False, protocol=None): + def __init__(self, stream, address, delegate, no_keep_alive=False, + protocol=None): self.stream = stream self.address = address # Save the socket's address family now so we know how to # interpret self.address even after the stream is closed # and its socket attribute replaced with None. self.address_family = stream.socket.family - self.request_callback = request_callback + self.delegate = delegate self.no_keep_alive = no_keep_alive - self.xheaders = xheaders if protocol: self.protocol = protocol elif isinstance(stream, iostream.SSLIOStream): self.protocol = "https" else: self.protocol = "http" + self._disconnect_on_finish = False self._clear_request_state() self.stream.set_close_callback(self._on_connection_close) self._finish_future = None @@ -68,17 +60,18 @@ class HTTP1Connection(object): while True: try: header_data = yield self.stream.read_until(b"\r\n\r\n") + request_delegate = self.delegate.start_request(self) self._finish_future = Future() start_line, headers = self._parse_headers(header_data) - request = self._make_request(start_line, headers) - self._request = request + self._disconnect_on_finish = not self._can_keep_alive( + start_line, headers) + request_delegate.headers_received(start_line, headers) body_future = self._read_body(headers) if body_future is not None: - request.body = yield body_future - self._parse_body(request) - self.request_callback(request) + request_delegate.data_received((yield body_future)) + request_delegate.finish() yield self._finish_future - except _BadRequestException as e: + except httputil.BadRequestException as e: gen_log.info("Malformed HTTP request from %r: %s", self.address, e) self.close() @@ -96,7 +89,6 @@ class HTTP1Connection(object): and when the connection is closed (to break up cycles and facilitate garbage collection in cpython). """ - self._request = None self._request_finished = False self._write_callback = None self._close_callback = None @@ -157,22 +149,22 @@ class HTTP1Connection(object): if self._request_finished and not self.stream.writing(): self._finish_request() + def _can_keep_alive(self, start_line, headers): + if self.no_keep_alive: + return False + connection_header = headers.get("Connection") + if connection_header is not None: + connection_header = connection_header.lower() + if start_line.endswith("HTTP/1.1"): + return connection_header != "close" + elif ("Content-Length" in headers + or start_line.startswith(("HEAD ", "GET "))): + return connection_header == "keep-alive" + return False + def _finish_request(self): - if self.no_keep_alive or self._request is None: - disconnect = True - else: - connection_header = self._request.headers.get("Connection") - if connection_header is not None: - connection_header = connection_header.lower() - if self._request.supports_http_1_1(): - disconnect = connection_header == "close" - elif ("Content-Length" in self._request.headers - or self._request.method in ("HEAD", "GET")): - disconnect = connection_header != "keep-alive" - else: - disconnect = True self._clear_request_state() - if disconnect: + if self._disconnect_on_finish: self.close() return # Turn Nagle's algorithm back on, leaving the stream in its @@ -188,59 +180,16 @@ class HTTP1Connection(object): headers = httputil.HTTPHeaders.parse(data[eol:]) except ValueError: # probably form split() if there was no ':' in the line - raise _BadRequestException("Malformed HTTP headers") + raise httputil.BadRequestException("Malformed HTTP headers") return start_line, headers - def _make_request(self, start_line, headers): - try: - method, uri, version = start_line.split(" ") - except ValueError: - raise _BadRequestException("Malformed HTTP request line") - if not version.startswith("HTTP/"): - raise _BadRequestException("Malformed HTTP version in HTTP Request-Line") - # HTTPRequest wants an IP, not a full socket address - if self.address_family in (socket.AF_INET, socket.AF_INET6): - remote_ip = self.address[0] - else: - # Unix (or other) socket; fake the remote address - remote_ip = '0.0.0.0' - - protocol = self.protocol - - # xheaders can override the defaults - if self.xheaders: - # Squid uses X-Forwarded-For, others use X-Real-Ip - ip = headers.get("X-Forwarded-For", remote_ip) - ip = ip.split(',')[-1].strip() - ip = headers.get("X-Real-Ip", ip) - if netutil.is_valid_ip(ip): - remote_ip = ip - # AWS uses X-Forwarded-Proto - proto_header = headers.get( - "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol)) - if proto_header in ("http", "https"): - protocol = proto_header - - return httputil.HTTPServerRequest( - connection=self, method=method, uri=uri, version=version, - headers=headers, remote_ip=remote_ip, protocol=protocol) - def _read_body(self, headers): content_length = headers.get("Content-Length") if content_length: content_length = int(content_length) if content_length > self.stream.max_buffer_size: - raise _BadRequestException("Content-Length too long") + raise httputil.BadRequestException("Content-Length too long") if headers.get("Expect") == "100-continue": self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n") return self.stream.read_bytes(content_length) return None - - def _parse_body(self, request): - if self._request.method in ("POST", "PATCH", "PUT"): - httputil.parse_body_arguments( - self._request.headers.get("Content-Type", ""), request.body, - self._request.body_arguments, self._request.files) - - for k, v in self._request.body_arguments.items(): - self._request.arguments.setdefault(k, []).extend(v) diff --git a/tornado/httpserver.py b/tornado/httpserver.py index e30bc32f5..44a2c94bd 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -28,12 +28,14 @@ class except to start a server at the beginning of the process from __future__ import absolute_import, division, print_function, with_statement +import socket from tornado import http1connection, httputil +from tornado import netutil from tornado.tcpserver import TCPServer -class HTTPServer(TCPServer): +class HTTPServer(TCPServer, httputil.HTTPConnectionDelegate): r"""A non-blocking, single-threaded HTTP server. A server is defined by a request callback that takes an HTTPRequest @@ -141,8 +143,66 @@ class HTTPServer(TCPServer): **kwargs) def handle_stream(self, stream, address): - HTTPConnection(stream, address, self.request_callback, - self.no_keep_alive, self.xheaders, self.protocol) + HTTPConnection(stream, address, self, self.no_keep_alive, self.protocol) + + def start_request(self, connection): + return _ServerRequestProcessor(self, connection) + +class _ServerRequestProcessor(httputil.HTTPStreamDelegate): + def __init__(self, server, connection): + self.server = server + self.connection = connection + + def headers_received(self, start_line, headers): + pass + try: + method, uri, version = start_line.split(" ") + except ValueError: + raise httputil.BadRequestException("Malformed HTTP request line") + if not version.startswith("HTTP/"): + raise httputil.BadRequestException("Malformed HTTP version in HTTP Request-Line") + # HTTPRequest wants an IP, not a full socket address + if self.connection.address_family in (socket.AF_INET, socket.AF_INET6): + remote_ip = self.connection.address[0] + else: + # Unix (or other) socket; fake the remote address + remote_ip = '0.0.0.0' + + protocol = self.connection.protocol + + # xheaders can override the defaults + if self.server.xheaders: + # Squid uses X-Forwarded-For, others use X-Real-Ip + ip = headers.get("X-Forwarded-For", remote_ip) + ip = ip.split(',')[-1].strip() + ip = headers.get("X-Real-Ip", ip) + if netutil.is_valid_ip(ip): + remote_ip = ip + # AWS uses X-Forwarded-Proto + proto_header = headers.get( + "X-Scheme", headers.get("X-Forwarded-Proto", protocol)) + if proto_header in ("http", "https"): + protocol = proto_header + + self.request = httputil.HTTPServerRequest( + connection=self.connection, method=method, uri=uri, version=version, + headers=headers, remote_ip=remote_ip, protocol=protocol) + + def data_received(self, chunk): + assert not self.request.body + self.request.body = chunk + + def finish(self): + if self.request.method in ("POST", "PATCH", "PUT"): + httputil.parse_body_arguments( + self.request.headers.get("Content-Type", ""), self.request.body, + self.request.body_arguments, self.request.files) + + for k, v in self.request.body_arguments.items(): + self.request.arguments.setdefault(k, []).extend(v) + + self.server.request_callback(self.request) + HTTPRequest = httputil.HTTPServerRequest diff --git a/tornado/httputil.py b/tornado/httputil.py index fac21ec08..8164d2eb8 100644 --- a/tornado/httputil.py +++ b/tornado/httputil.py @@ -411,6 +411,27 @@ class HTTPServerRequest(object): self.__class__.__name__, args, dict(self.headers)) +class BadRequestException(Exception): + """Exception class for malformed HTTP requests.""" + pass + + +class HTTPConnectionDelegate(object): + def start_request(self, connection): + raise NotImplementedError() + + +class HTTPStreamDelegate(object): + def headers_received(self, start_line, headers): + pass + + def data_received(self, chunk): + pass + + def finish(self): + pass + + def url_concat(url, args): """Concatenate url and argument dictionary regardless of whether url has existing query parameters.