import socket
+from tornado.concurrent import Future
from tornado.escape import native_str
+from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log
else:
self.protocol = "http"
self._clear_request_state()
- # Save stack context here, outside of any request. This keeps
- # contexts from one request from leaking into the next.
- self._header_callback = stack_context.wrap(self._on_headers)
self.stream.set_close_callback(self._on_connection_close)
- self.stream.read_until(b"\r\n\r\n", self._header_callback)
+ self._finish_future = None
+ # Register the future on the IOLoop so its errors get logged.
+ stream.io_loop.add_future(self._process_requests(),
+ lambda f: f.result())
+
+ @gen.coroutine
+ def _process_requests(self):
+ while True:
+ try:
+ header_data = yield self.stream.read_until(b"\r\n\r\n")
+ self._finish_future = Future()
+ start_line, headers = self._parse_headers(header_data)
+ request = self._make_request(start_line, headers)
+ self._request = request
+ body_future = self._read_body(headers)
+ if body_future is not None:
+ request.body = yield body_future
+ self._parse_body(request)
+ self.request_callback(request)
+ yield self._finish_future
+ except _BadRequestException as e:
+ gen_log.info("Malformed HTTP request from %r: %s",
+ self.address, e)
+ self.close()
+ return
+ except iostream.StreamClosedError:
+ self.close()
+ return
+
def _clear_request_state(self):
"""Clears the per-request state.
callback = self._close_callback
self._close_callback = None
callback()
+ if self._finish_future is not None and not self._finish_future.done():
+ self._finish_future.set_result(None)
# Delete any unfinished callbacks to break up reference cycles.
- self._header_callback = None
self._clear_request_state()
def close(self):
self.stream.close()
# Remove this reference to self, which would otherwise cause a
# cycle and delay garbage collection of this connection.
- self._header_callback = None
self._clear_request_state()
def write(self, chunk, callback=None):
if disconnect:
self.close()
return
+ # Turn Nagle's algorithm back on, leaving the stream in its
+ # default state for the next request.
+ self.stream.set_nodelay(False)
+ self._finish_future.set_result(None)
+
+ def _parse_headers(self, data):
+ data = native_str(data.decode('latin1'))
+ eol = data.find("\r\n")
+ start_line = data[:eol]
try:
- # Use a try/except instead of checking stream.closed()
- # directly, because in some cases the stream doesn't discover
- # that it's closed until you try to read from it.
- self.stream.read_until(b"\r\n\r\n", self._header_callback)
-
- # Turn Nagle's algorithm back on, leaving the stream in its
- # default state for the next request.
- self.stream.set_nodelay(False)
- except iostream.StreamClosedError:
- self.close()
+ headers = httputil.HTTPHeaders.parse(data[eol:])
+ except ValueError:
+ # probably form split() if there was no ':' in the line
+ raise _BadRequestException("Malformed HTTP headers")
+ return start_line, headers
- def _on_headers(self, data):
+ def _make_request(self, start_line, headers):
try:
- data = native_str(data.decode('latin1'))
- eol = data.find("\r\n")
- start_line = data[:eol]
- try:
- method, uri, version = start_line.split(" ")
- except ValueError:
- raise _BadRequestException("Malformed HTTP request line")
- if not version.startswith("HTTP/"):
- raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
- try:
- headers = httputil.HTTPHeaders.parse(data[eol:])
- except ValueError:
- # Probably from split() if there was no ':' in the line
- raise _BadRequestException("Malformed HTTP headers")
-
- # HTTPRequest wants an IP, not a full socket address
- if self.address_family in (socket.AF_INET, socket.AF_INET6):
- remote_ip = self.address[0]
- else:
- # Unix (or other) socket; fake the remote address
- remote_ip = '0.0.0.0'
-
- protocol = self.protocol
-
- # xheaders can override the defaults
- if self.xheaders:
- # Squid uses X-Forwarded-For, others use X-Real-Ip
- ip = headers.get("X-Forwarded-For", remote_ip)
- ip = ip.split(',')[-1].strip()
- ip = headers.get("X-Real-Ip", ip)
- if netutil.is_valid_ip(ip):
- remote_ip = ip
- # AWS uses X-Forwarded-Proto
- proto_header = headers.get(
- "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
- if proto_header in ("http", "https"):
- protocol = proto_header
-
- self._request = httputil.HTTPServerRequest(
- connection=self, method=method, uri=uri, version=version,
- headers=headers, remote_ip=remote_ip, protocol=protocol)
-
- content_length = headers.get("Content-Length")
- if content_length:
- content_length = int(content_length)
- if content_length > self.stream.max_buffer_size:
- raise _BadRequestException("Content-Length too long")
- if headers.get("Expect") == "100-continue":
- self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
- self.stream.read_bytes(content_length, self._on_request_body)
- return
-
- self.request_callback(self._request)
- except _BadRequestException as e:
- gen_log.info("Malformed HTTP request from %r: %s",
- self.address, e)
- self.close()
- return
-
- def _on_request_body(self, data):
- self._request.body = data
+ method, uri, version = start_line.split(" ")
+ except ValueError:
+ raise _BadRequestException("Malformed HTTP request line")
+ if not version.startswith("HTTP/"):
+ raise _BadRequestException("Malformed HTTP version in HTTP Request-Line")
+ # HTTPRequest wants an IP, not a full socket address
+ if self.address_family in (socket.AF_INET, socket.AF_INET6):
+ remote_ip = self.address[0]
+ else:
+ # Unix (or other) socket; fake the remote address
+ remote_ip = '0.0.0.0'
+
+ protocol = self.protocol
+
+ # xheaders can override the defaults
+ if self.xheaders:
+ # Squid uses X-Forwarded-For, others use X-Real-Ip
+ ip = headers.get("X-Forwarded-For", remote_ip)
+ ip = ip.split(',')[-1].strip()
+ ip = headers.get("X-Real-Ip", ip)
+ if netutil.is_valid_ip(ip):
+ remote_ip = ip
+ # AWS uses X-Forwarded-Proto
+ proto_header = headers.get(
+ "X-Scheme", headers.get("X-Forwarded-Proto", self.protocol))
+ if proto_header in ("http", "https"):
+ protocol = proto_header
+
+ return httputil.HTTPServerRequest(
+ connection=self, method=method, uri=uri, version=version,
+ headers=headers, remote_ip=remote_ip, protocol=protocol)
+
+ def _read_body(self, headers):
+ content_length = headers.get("Content-Length")
+ if content_length:
+ content_length = int(content_length)
+ if content_length > self.stream.max_buffer_size:
+ raise _BadRequestException("Content-Length too long")
+ if headers.get("Expect") == "100-continue":
+ self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
+ return self.stream.read_bytes(content_length)
+ return None
+
+ def _parse_body(self, request):
if self._request.method in ("POST", "PATCH", "PUT"):
httputil.parse_body_arguments(
- self._request.headers.get("Content-Type", ""), data,
+ self._request.headers.get("Content-Type", ""), request.body,
self._request.body_arguments, self._request.files)
for k, v in self._request.body_arguments.items():
self._request.arguments.setdefault(k, []).extend(v)
- self.request_callback(self._request)