This incidentally adds support for chunked request bodies on the server side.
We parse HTTP headers and bodies, and execute the request callback
-    until the HTTP conection is closed.
+    until the HTTP connection is closed.
"""
- def __init__(self, stream, address, delegate, no_keep_alive=False,
- protocol=None):
+ def __init__(self, stream, address, no_keep_alive=False, protocol=None):
self.stream = stream
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
self.address_family = stream.socket.family
- self.delegate = delegate
self.no_keep_alive = no_keep_alive
if protocol:
self.protocol = protocol
self._clear_request_state()
self.stream.set_close_callback(self._on_connection_close)
self._finish_future = None
+
+ def start_serving(self, delegate):
+ assert isinstance(delegate, httputil.HTTPConnectionDelegate)
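+        # The delegate is consulted once per request: start_request()
+        # returns a per-request delegate that then receives the
+        # headers_received/data_received/finish callbacks.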
# Register the future on the IOLoop so its errors get logged.
- stream.io_loop.add_future(self._process_requests(),
- lambda f: f.result())
+ self.stream.io_loop.add_future(self._process_requests(delegate),
+ lambda f: f.result())
@gen.coroutine
- def _process_requests(self):
+ def _process_requests(self, delegate):
while True:
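+            # Each iteration of this loop handles one request on a
+            # (possibly kept-alive) connection.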
+ request_delegate = delegate.start_request(self)
try:
- header_data = yield self.stream.read_until(b"\r\n\r\n")
- request_delegate = self.delegate.start_request(self)
- self._finish_future = Future()
- start_line, headers = self._parse_headers(header_data)
- self._disconnect_on_finish = not self._can_keep_alive(
- start_line, headers)
- request_delegate.headers_received(start_line, headers)
- body_future = self._read_body(headers)
- if body_future is not None:
- request_delegate.data_received((yield body_future))
- request_delegate.finish()
- yield self._finish_future
- except httputil.BadRequestException as e:
- gen_log.info("Malformed HTTP request from %r: %s",
- self.address, e)
- self.close()
- return
+ ret = yield self._process_message(request_delegate, False)
except iostream.StreamClosedError:
self.close()
return
+ if not ret:
+ return
+
+ def process_response(self, delegate, method):
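+        # 'method' is the verb of the request that elicited this
+        # response; responses to HEAD never carry a body even when
+        # they advertise a Content-Length.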
+ return self._process_message(delegate, True, method=method)
+
+ @gen.coroutine
+ def _process_message(self, delegate, is_client, method=None):
+ assert isinstance(delegate, httputil.HTTPStreamDelegate)
+ try:
+ header_data = yield self.stream.read_until_regex(b"\r?\n\r?\n")
+ self._finish_future = Future()
+ start_line, headers = self._parse_headers(header_data)
+ self._disconnect_on_finish = not self._can_keep_alive(
+ start_line, headers)
+ ret = delegate.headers_received(start_line, headers)
+ # TODO: finalize the 'detach' interface.
+ if ret == 'detach':
+ return
+ skip_body = False
+ if is_client:
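+                # Responses to HEAD requests and 304 responses may
+                # include entity headers such as Content-Length, but
+                # never a body.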
+ if method == 'HEAD':
+ skip_body = True
+ code = httputil.parse_response_start_line(start_line).code
+ if code == 304:
+ skip_body = True
+            if 100 <= code < 200:
+                    # Informational responses are followed by the real
+                    # response; process it recursively and stop here, as
+                    # the recursive call handles the body and finish().
+                    ret = yield self._process_message(delegate, is_client,
+                                                      method=method)
+                    raise gen.Return(ret)
+ else:
+ if headers.get("Expect") == "100-continue":
+ self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
+ if not skip_body:
+ body_future = self._read_body(is_client, headers, delegate)
+ if body_future is not None:
+ yield body_future
+ delegate.finish()
+ yield self._finish_future
+ except httputil.BadRequestException as e:
+ gen_log.info("Malformed HTTP request from %r: %s",
+ self.address, e)
+ self.close()
+ raise gen.Return(False)
+ raise gen.Return(True)
def _clear_request_state(self):
raise httputil.BadRequestException("Malformed HTTP headers")
return start_line, headers
- def _read_body(self, headers):
+ def _read_body(self, is_client, headers, delegate):
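+        # Bodies are framed by, in order of precedence: a
+        # Content-Length header, chunked Transfer-Encoding, or (for
+        # responses only) the peer closing the connection.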
content_length = headers.get("Content-Length")
if content_length:
content_length = int(content_length)
if content_length > self.stream.max_buffer_size:
raise httputil.BadRequestException("Content-Length too long")
- if headers.get("Expect") == "100-continue":
- self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
- return self.stream.read_bytes(content_length)
+ return self._read_fixed_body(content_length, delegate)
+ if headers.get("Transfer-Encoding") == "chunked":
+ return self._read_chunked_body(delegate)
+ if is_client:
+ return self._read_body_until_close(delegate)
return None
+
+ @gen.coroutine
+ def _read_fixed_body(self, content_length, delegate):
+ body = yield self.stream.read_bytes(content_length)
+ delegate.data_received(body)
+
+ @gen.coroutine
+ def _read_chunked_body(self, delegate):
+ # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
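+        # Each chunk is a hex length line followed by that many bytes,
+        # e.g. b"5\r\nhello\r\n0\r\n\r\n" encodes the body b"hello".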
+ while True:
+ chunk_len = yield self.stream.read_until(b"\r\n")
+ chunk_len = int(chunk_len.strip(), 16)
+            if chunk_len == 0:
+                # The last chunk is followed by a final CRLF (trailers
+                # are not supported); consume it so a kept-alive
+                # connection doesn't see stray bytes.
+                yield self.stream.read_bytes(2)
+                return
+ # chunk ends with \r\n
+ chunk = yield self.stream.read_bytes(chunk_len + 2)
+ assert chunk[-2:] == b"\r\n"
+ delegate.data_received(chunk[:-2])
+
+ @gen.coroutine
+ def _read_body_until_close(self, delegate):
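+        # HTTP/1.0-style responses that have neither Content-Length
+        # nor chunked Transfer-Encoding end when the server closes
+        # the connection.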
+ body = yield self.stream.read_until_close()
+ delegate.data_received(body)
**kwargs)
def handle_stream(self, stream, address):
- HTTPConnection(stream, address, self, self.no_keep_alive, self.protocol)
+ conn = HTTPConnection(stream, address, self.no_keep_alive,
+ self.protocol)
+ conn.start_serving(self)
def start_request(self, connection):
return _ServerRequestProcessor(self, connection)
import datetime
import email.utils
import numbers
+import re
import time
from tornado.escape import native_str, parse_qs_bytes, utf8
raise TypeError("unknown timestamp type: %r" % ts)
return email.utils.formatdate(ts, usegmt=True)
+
+ResponseStartLine = collections.namedtuple(
+ 'ResponseStartLine', ['version', 'code', 'reason'])
+
+def parse_response_start_line(line):
+ """Returns a (version, code, reason) tuple for an HTTP 1.x response line.
+
+    The return value is a `collections.namedtuple`.
+
+ >>> parse_response_start_line("HTTP/1.1 200 OK")
+ ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
+ """
+ line = native_str(line)
+    match = re.match(r"(HTTP/1\.[01]) ([0-9]+) ([^\r]*)", line)
+ assert match
+ return ResponseStartLine(match.group(1), int(match.group(2)),
+ match.group(3))
+
# _parseparam and _parse_header are copied and modified from python2.7's cgi.py
# The original 2.7 version of this code did not correctly support some
# combinations of semicolons and double quotes.
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
-from tornado.escape import utf8, _unicode, native_str
+from tornado.escape import utf8, _unicode
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
-from tornado.httputil import HTTPHeaders
-from tornado.iostream import IOStream, SSLIOStream
+from tornado import httputil
+from tornado.http1connection import HTTP1Connection
+from tornado.iostream import IOStream, SSLIOStream, StreamClosedError
from tornado.netutil import Resolver, OverrideResolver
from tornado.log import gen_log
from tornado import stack_context
del self.waiting[key]
-class _HTTPConnection(object):
+class _HTTPConnection(httputil.HTTPStreamDelegate):
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
def __init__(self, io_loop, client, request, release_callback,
self.resolver = resolver
self.code = None
self.headers = None
- self.chunks = None
+ self.chunks = []
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
self._timeout = None
+ self._sockaddr = None
with stack_context.ExceptionStackContext(self._handle_exception):
self.parsed = urlparse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
self.stream.set_close_callback(self._on_close)
# ipv6 addresses are broken (in self.parsed.hostname) until
# 2.7, here is correctly parsed value calculated in __init__
- sockaddr = addrinfo[0][1]
- self.stream.connect(sockaddr, self._on_connect,
+ self._sockaddr = addrinfo[0][1]
+ self.stream.connect(self._sockaddr, self._on_connect,
server_hostname=self.parsed_hostname)
def _create_stream(self, addrinfo):
request_str += self.request.body
self.stream.set_nodelay(True)
self.stream.write(request_str)
- self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
+ self.connection = HTTP1Connection(
+ self.stream, self._sockaddr,
+ no_keep_alive=True, protocol=self.parsed.scheme)
+ # Ensure that any exception raised in process_response ends up in our
+ # stack context.
+ self.io_loop.add_future(
+ self.connection.process_response(self, method=self.request.method),
+ lambda f: f.result())
def _release(self):
if self.release_callback is not None:
def _handle_exception(self, typ, value, tb):
if self.final_callback:
self._remove_timeout()
+ if isinstance(value, StreamClosedError):
+ value = HTTPError(599, "Stream closed")
self._run_callback(HTTPResponse(self.request, 599, error=value,
request_time=self.io_loop.time() - self.start_time,
))
if hasattr(self, "stream"):
+ # TODO: this may cause a StreamClosedError to be raised
+ # by the connection's Future. Should we cancel the
+ # connection more gracefully?
self.stream.close()
return True
else:
# If our callback has already been called, we are probably
# catching an exception that is not caused by us but rather
# some child of our callback. Rather than drop it on the floor,
- # pass it along.
- return False
+ # pass it along, unless it's just the stream being closed.
+ return isinstance(value, StreamClosedError)
def _on_close(self):
if self.final_callback is not None:
message = str(self.stream.error)
raise HTTPError(599, message)
- def _handle_1xx(self, code):
- self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
-
- def _on_headers(self, data):
- data = native_str(data.decode("latin1"))
- first_line, _, header_data = data.partition("\n")
- match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
- assert match
- code = int(match.group(1))
- self.headers = HTTPHeaders.parse(header_data)
- if 100 <= code < 200:
- self._handle_1xx(code)
- return
- else:
- self.code = code
- self.reason = match.group(2)
+ def headers_received(self, first_line, headers):
+ self.headers = headers
+ version, code, reason = httputil.parse_response_start_line(first_line)
+ self.code = code
+ self.reason = reason
if "Content-Length" in self.headers:
if "," in self.headers["Content-Length"]:
if self.request.header_callback is not None:
-            # re-attach the newline we split on earlier
-            self.request.header_callback(first_line + _)
+            # first_line is delivered without its trailing newline,
+            # so restore one before invoking the callback.
+            self.request.header_callback(first_line + '\r\n')
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback('\r\n')
- if self.request.method == "HEAD" or self.code == 304:
- # HEAD requests and 304 responses never have content, even
- # though they may have content-length headers
- self._on_body(b"")
- return
if 100 <= self.code < 200 or self.code == 204:
# These response codes never have bodies
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
content_length not in (None, 0)):
raise ValueError("Response with code %d should not have body" %
self.code)
- self._on_body(b"")
- return
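+            # TODO: without the early return, responses with these
+            # codes now fall through to the connection's body framing,
+            # which reads until close if the server leaves the
+            # connection open.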
if (self.request.use_gzip and
self.headers.get("Content-Encoding") == "gzip"):
self._decompressor = GzipDecompressor()
- if self.headers.get("Transfer-Encoding") == "chunked":
- self.chunks = []
- self.stream.read_until(b"\r\n", self._on_chunk_length)
- elif content_length is not None:
- self.stream.read_bytes(content_length, self._on_body)
- else:
- self.stream.read_until_close(self._on_body)
- def _on_body(self, data):
+ def finish(self):
+ if self._decompressor is not None:
+ tail = self._decompressor.flush()
+ if tail:
+ # I believe the tail will always be empty (i.e.
+ # decompress will return all it can). The purpose
+ # of the flush call is to detect errors such
+ # as truncated input. But in case it ever returns
+ # anything, treat it as an extra chunk
+ if self.request.streaming_callback is not None:
+ self.request.streaming_callback(tail)
+ else:
+ self.chunks.append(tail)
+ data = b''.join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request",
self.request)
self.client.fetch(new_request, final_callback)
self._on_end_request()
return
- if self._decompressor:
- data = (self._decompressor.decompress(data) +
- self._decompressor.flush())
if self.request.streaming_callback:
- if self.chunks is None:
- # if chunks is not None, we already called streaming_callback
- # in _on_chunk_data
- self.request.streaming_callback(data)
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(original_request,
- self.code, reason=self.reason,
+ self.code, reason=getattr(self, 'reason', None),
headers=self.headers,
request_time=self.io_loop.time() - self.start_time,
buffer=buffer,
def _on_end_request(self):
self.stream.close()
- def _on_chunk_length(self, data):
- # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
- length = int(data.strip(), 16)
- if length == 0:
- if self._decompressor is not None:
- tail = self._decompressor.flush()
- if tail:
- # I believe the tail will always be empty (i.e.
- # decompress will return all it can). The purpose
- # of the flush call is to detect errors such
- # as truncated input. But in case it ever returns
- # anything, treat it as an extra chunk
- if self.request.streaming_callback is not None:
- self.request.streaming_callback(tail)
- else:
- self.chunks.append(tail)
- # all the data has been decompressed, so we don't need to
- # decompress again in _on_body
- self._decompressor = None
- self._on_body(b''.join(self.chunks))
- else:
- self.stream.read_bytes(length + 2, # chunk ends with \r\n
- self._on_chunk_data)
-
- def _on_chunk_data(self, data):
- assert data[-2:] == b"\r\n"
- chunk = data[:-2]
+ def data_received(self, chunk):
if self._decompressor:
chunk = self._decompressor.decompress(chunk)
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
- self.stream.read_until(b"\r\n", self._on_chunk_length)
if __name__ == "__main__":
from __future__ import absolute_import, division, print_function, with_statement
from tornado import httpclient, simple_httpclient, netutil
from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
+from tornado.http1connection import HTTP1Connection
from tornado.httpserver import HTTPServer
-from tornado.httputil import HTTPHeaders
+from tornado.httputil import HTTPHeaders, HTTPStreamDelegate
from tornado.iostream import IOStream
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context, Resolver
})
-class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
- def set_request(self, request):
- self.__next_request = request
-
- def _on_connect(self):
- self.stream.write(self.__next_request)
- self.__next_request = None
- self.stream.read_until(b"\r\n\r\n", self._on_headers)
-
# This test is also called from wsgi_test
-
-
class HTTPConnectionTest(AsyncHTTPTestCase):
def get_handlers(self):
return [("/multipart", MultipartTestHandler),
return Application(self.get_handlers())
def raw_fetch(self, headers, body):
- with closing(Resolver(io_loop=self.io_loop)) as resolver:
- with closing(SimpleAsyncHTTPClient(self.io_loop,
- resolver=resolver)) as client:
- conn = RawRequestHTTPConnection(
- self.io_loop, client,
- httpclient._RequestProxy(
- httpclient.HTTPRequest(self.get_url("/")),
- dict(httpclient.HTTPRequest._DEFAULTS)),
- None, self.stop,
- 1024 * 1024, resolver)
- conn.set_request(
- b"\r\n".join(headers +
- [utf8("Content-Length: %d\r\n" % len(body))]) +
- b"\r\n" + body)
- response = self.wait()
- response.rethrow()
- return response
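+        # Write the raw request bytes over a bare IOStream and let
+        # HTTP1Connection parse whatever comes back.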
+ with closing(IOStream(socket.socket())) as stream:
+ stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
+ self.wait()
+ stream.write(
+ b"\r\n".join(headers +
+ [utf8("Content-Length: %d\r\n" % len(body))]) +
+ b"\r\n" + body)
+ chunks = []
+ test = self
+ class Delegate(HTTPStreamDelegate):
+ def data_received(self, chunk):
+ chunks.append(chunk)
+
+ def finish(self):
+ test.stop()
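+            # The connection only uses the address for logging, so
+            # None is fine here.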
+ conn = HTTP1Connection(stream, None)
+ conn.process_response(Delegate(), method='GET')
+ self.wait()
+ return b''.join(chunks)
def test_multipart_form(self):
# Encodings here are tricky: Headers are latin1, bodies can be
b"--1234567890--",
b"",
]))
- data = json_decode(response.body)
+ data = json_decode(response)
self.assertEqual(u("\u00e9"), data["header"])
self.assertEqual(u("\u00e1"), data["argument"])
self.assertEqual(u("\u00f3"), data["filename"])
import socket
import sys
+from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
self.write(self.request.headers["Host"])
+class NoContentLengthHandler(RequestHandler):
+ @gen.coroutine
+ def get(self):
+ # Emulate the old HTTP/1.0 behavior of returning a body with no
+ # content-length. Tornado handles content-length at the framework
+ # level so we have to go around it.
+ stream = self.request.connection.stream
+ yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
+ b"hello")
+ stream.close()
+
+
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
+ url("/no_content_length", NoContentLengthHandler),
], gzip=True)
def test_singleton(self):
self.triggers.popleft()()
self.wait()
+ def test_no_content_length(self):
+ response = self.fetch("/no_content_length")
+        self.assertEqual(b"hello", response.body)
+
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
self.connect_future.set_exception(WebSocketError(
"Non-websocket response"))
- def _handle_1xx(self, code):
- assert code == 101
+ def headers_received(self, start_line, headers):
+ code = httputil.parse_response_start_line(start_line).code
+
+ if code != 101:
+ return super(WebSocketClientConnection, self).headers_received(
+ start_line, headers)
+
+ self.headers = headers
assert self.headers['Upgrade'].lower() == 'websocket'
assert self.headers['Connection'].lower() == 'upgrade'
accept = WebSocketProtocol13.compute_accept_value(self.key)
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
+ self.stream.set_close_callback(self._on_close)
+
self.connect_future.set_result(self)
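+        # Detach from the HTTP1Connection: from this point on the
+        # stream carries websocket frames, not HTTP.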
+ return 'detach'
def write_message(self, message, binary=False):
"""Sends a message to the WebSocket server."""