from __future__ import absolute_import, division, print_function, with_statement
-import socket
-
from tornado.concurrent import Future
from tornado.escape import native_str
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log
-from tornado import netutil
from tornado import stack_context
-class _BadRequestException(Exception):
- """Exception class for malformed HTTP requests."""
- pass
-
-
class HTTP1Connection(object):
"""Handles a connection to an HTTP client, executing HTTP requests.
We parse HTTP headers and bodies, and execute the request callback
until the HTTP conection is closed.
"""
- def __init__(self, stream, address, request_callback, no_keep_alive=False,
- xheaders=False, protocol=None):
+ def __init__(self, stream, address, delegate, no_keep_alive=False,
+ protocol=None):
self.stream = stream
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
self.address_family = stream.socket.family
- self.request_callback = request_callback
+ self.delegate = delegate
self.no_keep_alive = no_keep_alive
- self.xheaders = xheaders
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
self.protocol = "https"
else:
self.protocol = "http"
+ self._disconnect_on_finish = False
self._clear_request_state()
self.stream.set_close_callback(self._on_connection_close)
self._finish_future = None
while True:
try:
header_data = yield self.stream.read_until(b"\r\n\r\n")
+ request_delegate = self.delegate.start_request(self)
self._finish_future = Future()
start_line, headers = self._parse_headers(header_data)
- request = self._make_request(start_line, headers)
- self._request = request
+ self._disconnect_on_finish = not self._can_keep_alive(
+ start_line, headers)
+ request_delegate.headers_received(start_line, headers)
body_future = self._read_body(headers)
if body_future is not None:
- request.body = yield body_future
- self._parse_body(request)
- self.request_callback(request)
+ request_delegate.data_received((yield body_future))
+ request_delegate.finish()
yield self._finish_future
- except _BadRequestException as e:
+ except httputil.BadRequestException as e:
gen_log.info("Malformed HTTP request from %r: %s",
self.address, e)
self.close()
and when the connection is closed (to break up cycles and
facilitate garbage collection in cpython).
"""
- self._request = None
self._request_finished = False
self._write_callback = None
self._close_callback = None
if self._request_finished and not self.stream.writing():
self._finish_request()
+ def _can_keep_alive(self, start_line, headers):
+ if self.no_keep_alive:
+ return False
+ connection_header = headers.get("Connection")
+ if connection_header is not None:
+ connection_header = connection_header.lower()
+ if start_line.endswith("HTTP/1.1"):
+ return connection_header != "close"
+ elif ("Content-Length" in headers
+ or start_line.startswith(("HEAD ", "GET "))):
+ return connection_header == "keep-alive"
+ return False
+
def _finish_request(self):
- if self.no_keep_alive or self._request is None:
- disconnect = True
- else:
- connection_header = self._request.headers.get("Connection")
- if connection_header is not None:
- connection_header = connection_header.lower()
- if self._request.supports_http_1_1():
- disconnect = connection_header == "close"
- elif ("Content-Length" in self._request.headers
- or self._request.method in ("HEAD", "GET")):
- disconnect = connection_header != "keep-alive"
- else:
- disconnect = True
self._clear_request_state()
- if disconnect:
+ if self._disconnect_on_finish:
self.close()
return
# Turn Nagle's algorithm back on, leaving the stream in its
headers = httputil.HTTPHeaders.parse(data[eol:])
except ValueError:
# probably form split() if there was no ':' in the line
- raise _BadRequestException("Malformed HTTP headers")
+ raise httputil.BadRequestException("Malformed HTTP headers")
return start_line, headers
- def _make_request(self, start_line, headers):
- try:
- method, uri, version = start_line.split(" ")
- except ValueError:
- raise _BadRequestException("Malformed HTTP request line")
- if not version.startswith("HTTP/"):
- raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
- # HTTPRequest wants an IP, not a full socket address
- if self.address_family in (socket.AF_INET, socket.AF_INET6):
- remote_ip = self.address[0]
- else:
- # Unix (or other) socket; fake the remote address
- remote_ip = '0.0.0.0'
-
- protocol = self.protocol
-
- # xheaders can override the defaults
- if self.xheaders:
- # Squid uses X-Forwarded-For, others use X-Real-Ip
- ip = headers.get("X-Forwarded-For", remote_ip)
- ip = ip.split(',')[-1].strip()
- ip = headers.get("X-Real-Ip", ip)
- if netutil.is_valid_ip(ip):
- remote_ip = ip
- # AWS uses X-Forwarded-Proto
- proto_header = headers.get(
- "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
- if proto_header in ("http", "https"):
- protocol = proto_header
-
- return httputil.HTTPServerRequest(
- connection=self, method=method, uri=uri, version=version,
- headers=headers, remote_ip=remote_ip, protocol=protocol)
-
def _read_body(self, headers):
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
if content_length > self.stream.max_buffer_size:
- raise _BadRequestException("Content-Length too long")
+ raise httputil.BadRequestException("Content-Length too long")
if headers.get("Expect") == "100-continue":
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
return self.stream.read_bytes(content_length)
return None
-
- def _parse_body(self, request):
- if self._request.method in ("POST", "PATCH", "PUT"):
- httputil.parse_body_arguments(
- self._request.headers.get("Content-Type", ""), request.body,
- self._request.body_arguments, self._request.files)
-
- for k, v in self._request.body_arguments.items():
- self._request.arguments.setdefault(k, []).extend(v)
from __future__ import absolute_import, division, print_function, with_statement
+import socket
from tornado import http1connection, httputil
+from tornado import netutil
from tornado.tcpserver import TCPServer
-class HTTPServer(TCPServer):
+class HTTPServer(TCPServer, httputil.HTTPConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a request callback that takes an HTTPRequest
**kwargs)
def handle_stream(self, stream, address):
- HTTPConnection(stream, address, self.request_callback,
- self.no_keep_alive, self.xheaders, self.protocol)
+ HTTPConnection(stream, address, self, self.no_keep_alive, self.protocol)
+
+ def start_request(self, connection):
+ return _ServerRequestProcessor(self, connection)
+
+class _ServerRequestProcessor(httputil.HTTPStreamDelegate):
+ def __init__(self, server, connection):
+ self.server = server
+ self.connection = connection
+
+ def headers_received(self, start_line, headers):
+ pass
+ try:
+ method, uri, version = start_line.split(" ")
+ except ValueError:
+ raise httputil.BadRequestException("Malformed HTTP request line")
+ if not version.startswith("HTTP/"):
+ raise httputil.BadRequestException("Malformed HTTP version in HTTP Request-Line")
+ # HTTPRequest wants an IP, not a full socket address
+ if self.connection.address_family in (socket.AF_INET, socket.AF_INET6):
+ remote_ip = self.connection.address[0]
+ else:
+ # Unix (or other) socket; fake the remote address
+ remote_ip = '0.0.0.0'
+
+ protocol = self.connection.protocol
+
+ # xheaders can override the defaults
+ if self.server.xheaders:
+ # Squid uses X-Forwarded-For, others use X-Real-Ip
+ ip = headers.get("X-Forwarded-For", remote_ip)
+ ip = ip.split(',')[-1].strip()
+ ip = headers.get("X-Real-Ip", ip)
+ if netutil.is_valid_ip(ip):
+ remote_ip = ip
+ # AWS uses X-Forwarded-Proto
+ proto_header = headers.get(
+ "X-Scheme", headers.get("X-Forwarded-Proto", protocol))
+ if proto_header in ("http", "https"):
+ protocol = proto_header
+
+ self.request = httputil.HTTPServerRequest(
+ connection=self.connection, method=method, uri=uri, version=version,
+ headers=headers, remote_ip=remote_ip, protocol=protocol)
+
+ def data_received(self, chunk):
+ assert not self.request.body
+ self.request.body = chunk
+
+ def finish(self):
+ if self.request.method in ("POST", "PATCH", "PUT"):
+ httputil.parse_body_arguments(
+ self.request.headers.get("Content-Type", ""), self.request.body,
+ self.request.body_arguments, self.request.files)
+
+ for k, v in self.request.body_arguments.items():
+ self.request.arguments.setdefault(k, []).extend(v)
+
+ self.server.request_callback(self.request)
+
HTTPRequest = httputil.HTTPServerRequest