]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Use HTTP1Connection in SimpleAsyncHTTPClient.
authorBen Darnell <ben@bendarnell.com>
Sun, 2 Mar 2014 20:42:19 +0000 (15:42 -0500)
committerBen Darnell <ben@bendarnell.com>
Sun, 2 Mar 2014 20:48:26 +0000 (15:48 -0500)
This incidentally adds support for chunked request bodies on the server side.

tornado/http1connection.py
tornado/httpserver.py
tornado/httputil.py
tornado/simple_httpclient.py
tornado/test/httpserver_test.py
tornado/test/simple_httpclient_test.py
tornado/websocket.py

index b491de6c476da7a3fe8b232709591cf263f4b26c..a7cd9cbcc23f7e3985e27095904f35d0174129d3 100644 (file)
@@ -31,15 +31,13 @@ class HTTP1Connection(object):
     We parse HTTP headers and bodies, and execute the request callback
     until the HTTP conection is closed.
     """
-    def __init__(self, stream, address, delegate, no_keep_alive=False,
-                 protocol=None):
+    def __init__(self, stream, address, no_keep_alive=False, protocol=None):
         self.stream = stream
         self.address = address
         # Save the socket's address family now so we know how to
         # interpret self.address even after the stream is closed
         # and its socket attribute replaced with None.
         self.address_family = stream.socket.family
-        self.delegate = delegate
         self.no_keep_alive = no_keep_alive
         if protocol:
             self.protocol = protocol
@@ -51,34 +49,65 @@ class HTTP1Connection(object):
         self._clear_request_state()
         self.stream.set_close_callback(self._on_connection_close)
         self._finish_future = None
+
+    def start_serving(self, delegate):
+        assert isinstance(delegate, httputil.HTTPConnectionDelegate)
         # Register the future on the IOLoop so its errors get logged.
-        stream.io_loop.add_future(self._process_requests(),
-                                  lambda f: f.result())
+        self.stream.io_loop.add_future(self._process_requests(delegate),
+                                       lambda f: f.result())
 
     @gen.coroutine
-    def _process_requests(self):
+    def _process_requests(self, delegate):
         while True:
+            request_delegate = delegate.start_request(self)
             try:
-                header_data = yield self.stream.read_until(b"\r\n\r\n")
-                request_delegate = self.delegate.start_request(self)
-                self._finish_future = Future()
-                start_line, headers = self._parse_headers(header_data)
-                self._disconnect_on_finish = not self._can_keep_alive(
-                    start_line, headers)
-                request_delegate.headers_received(start_line, headers)
-                body_future = self._read_body(headers)
-                if body_future is not None:
-                    request_delegate.data_received((yield body_future))
-                request_delegate.finish()
-                yield self._finish_future
-            except httputil.BadRequestException as e:
-                gen_log.info("Malformed HTTP request from %r: %s",
-                             self.address, e)
-                self.close()
-                return
+                ret = yield self._process_message(request_delegate, False)
             except iostream.StreamClosedError:
                 self.close()
                 return
+            if not ret:
+                return
+
+    def process_response(self, delegate, method):
+        return self._process_message(delegate, True, method=method)
+
+    @gen.coroutine
+    def _process_message(self, delegate, is_client, method=None):
+        assert isinstance(delegate, httputil.HTTPStreamDelegate)
+        try:
+            header_data = yield self.stream.read_until_regex(b"\r?\n\r?\n")
+            self._finish_future = Future()
+            start_line, headers = self._parse_headers(header_data)
+            self._disconnect_on_finish = not self._can_keep_alive(
+                start_line, headers)
+            ret = delegate.headers_received(start_line, headers)
+            # TODO: finalize the 'detach' interface.
+            if ret == 'detach':
+                return
+            skip_body = False
+            if is_client:
+                if method == 'HEAD':
+                    skip_body = True
+                code = httputil.parse_response_start_line(start_line).code
+                if code == 304:
+                    skip_body = True
+                if code >= 100 and code < 200:
+                    yield self._process_message(delegate, is_client, method=method)
+            else:
+                if headers.get("Expect") == "100-continue":
+                    self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
+            if not skip_body:
+                body_future = self._read_body(is_client, headers, delegate)
+                if body_future is not None:
+                    yield body_future
+            delegate.finish()
+            yield self._finish_future
+        except httputil.BadRequestException as e:
+            gen_log.info("Malformed HTTP request from %r: %s",
+                         self.address, e)
+            self.close()
+            raise gen.Return(False)
+        raise gen.Return(True)
 
 
     def _clear_request_state(self):
@@ -183,13 +212,38 @@ class HTTP1Connection(object):
             raise httputil.BadRequestException("Malformed HTTP headers")
         return start_line, headers
 
-    def _read_body(self, headers):
+    def _read_body(self, is_client, headers, delegate):
         content_length = headers.get("Content-Length")
         if content_length:
             content_length = int(content_length)
             if content_length > self.stream.max_buffer_size:
                 raise httputil.BadRequestException("Content-Length too long")
-            if headers.get("Expect") == "100-continue":
-                self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
-            return self.stream.read_bytes(content_length)
+            return self._read_fixed_body(content_length, delegate)
+        if headers.get("Transfer-Encoding") == "chunked":
+            return self._read_chunked_body(delegate)
+        if is_client:
+            return self._read_body_until_close(delegate)
         return None
+
+    @gen.coroutine
+    def _read_fixed_body(self, content_length, delegate):
+        body = yield self.stream.read_bytes(content_length)
+        delegate.data_received(body)
+
+    @gen.coroutine
+    def _read_chunked_body(self, delegate):
+        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
+        while True:
+            chunk_len = yield self.stream.read_until(b"\r\n")
+            chunk_len = int(chunk_len.strip(), 16)
+            if chunk_len == 0:
+                return
+            # chunk ends with \r\n
+            chunk = yield self.stream.read_bytes(chunk_len + 2)
+            assert chunk[-2:] == b"\r\n"
+            delegate.data_received(chunk[:-2])
+
+    @gen.coroutine
+    def _read_body_until_close(self, delegate):
+        body = yield self.stream.read_until_close()
+        delegate.data_received(body)
index 44a2c94bde477d26d863710fcf2c393e28f94db2..8454c3d7cdf4f20deded0d74c3667d799061a444 100644 (file)
@@ -143,7 +143,9 @@ class HTTPServer(TCPServer, httputil.HTTPConnectionDelegate):
                            **kwargs)
 
     def handle_stream(self, stream, address):
-        HTTPConnection(stream, address, self, self.no_keep_alive, self.protocol)
+        conn = HTTPConnection(stream, address, self.no_keep_alive,
+                              self.protocol)
+        conn.start_serving(self)
 
     def start_request(self, connection):
         return _ServerRequestProcessor(self, connection)
index 8164d2eb84fae414a72a04ab8c63ed9886c1d5c9..d7739d401d7f93ea06bde270bfcd5c63fc4d0221 100644 (file)
@@ -28,6 +28,7 @@ import copy
 import datetime
 import email.utils
 import numbers
+import re
 import time
 
 from tornado.escape import native_str, parse_qs_bytes, utf8
@@ -625,6 +626,24 @@ def format_timestamp(ts):
         raise TypeError("unknown timestamp type: %r" % ts)
     return email.utils.formatdate(ts, usegmt=True)
 
+
+ResponseStartLine = collections.namedtuple(
+    'ResponseStartLine', ['version', 'code', 'reason'])
+
+def parse_response_start_line(line):
+    """Returns a (version, code, reason) tuple for an HTTP 1.x response line.
+
+    The response is a `collections.namedtuple`.
+
+    >>> parse_response_start_line("HTTP/1.1 200 OK")
+    ResponseStartLine(version='HTTP/1.1', code=200, reason='OK')
+    """
+    line = native_str(line)
+    match = re.match("(HTTP/1.[01]) ([0-9]+) ([^\r]*)", line)
+    assert match
+    return ResponseStartLine(match.group(1), int(match.group(2)),
+                             match.group(3))
+
 # _parseparam and _parse_header are copied and modified from python2.7's cgi.py
 # The original 2.7 version of this code did not correctly support some
 # combinations of semicolons and double quotes.
index 73bfee89e4c26b12cefeb12e90af6db2e6c30eba..0bd92e299959486d98e315badd3b1af3a7f32aac 100644 (file)
@@ -1,10 +1,11 @@
 #!/usr/bin/env python
 from __future__ import absolute_import, division, print_function, with_statement
 
-from tornado.escape import utf8, _unicode, native_str
+from tornado.escape import utf8, _unicode
 from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
-from tornado.httputil import HTTPHeaders
-from tornado.iostream import IOStream, SSLIOStream
+from tornado import httputil
+from tornado.http1connection import HTTP1Connection
+from tornado.iostream import IOStream, SSLIOStream, StreamClosedError
 from tornado.netutil import Resolver, OverrideResolver
 from tornado.log import gen_log
 from tornado import stack_context
@@ -142,7 +143,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
         del self.waiting[key]
 
 
-class _HTTPConnection(object):
+class _HTTPConnection(httputil.HTTPStreamDelegate):
     _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
 
     def __init__(self, io_loop, client, request, release_callback,
@@ -157,10 +158,11 @@ class _HTTPConnection(object):
         self.resolver = resolver
         self.code = None
         self.headers = None
-        self.chunks = None
+        self.chunks = []
         self._decompressor = None
         # Timeout handle returned by IOLoop.add_timeout
         self._timeout = None
+        self._sockaddr = None
         with stack_context.ExceptionStackContext(self._handle_exception):
             self.parsed = urlparse.urlsplit(_unicode(self.request.url))
             if self.parsed.scheme not in ("http", "https"):
@@ -205,8 +207,8 @@ class _HTTPConnection(object):
         self.stream.set_close_callback(self._on_close)
         # ipv6 addresses are broken (in self.parsed.hostname) until
         # 2.7, here is correctly parsed value calculated in __init__
-        sockaddr = addrinfo[0][1]
-        self.stream.connect(sockaddr, self._on_connect,
+        self._sockaddr = addrinfo[0][1]
+        self.stream.connect(self._sockaddr, self._on_connect,
                             server_hostname=self.parsed_hostname)
 
     def _create_stream(self, addrinfo):
@@ -333,7 +335,14 @@ class _HTTPConnection(object):
             request_str += self.request.body
         self.stream.set_nodelay(True)
         self.stream.write(request_str)
-        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
+        self.connection = HTTP1Connection(
+            self.stream, self._sockaddr,
+            no_keep_alive=True, protocol=self.parsed.scheme)
+        # Ensure that any exception raised in process_response ends up in our
+        # stack context.
+        self.io_loop.add_future(
+            self.connection.process_response(self, method=self.request.method),
+            lambda f: f.result())
 
     def _release(self):
         if self.release_callback is not None:
@@ -351,19 +360,24 @@ class _HTTPConnection(object):
     def _handle_exception(self, typ, value, tb):
         if self.final_callback:
             self._remove_timeout()
+            if isinstance(value, StreamClosedError):
+                value = HTTPError(599, "Stream closed")
             self._run_callback(HTTPResponse(self.request, 599, error=value,
                                             request_time=self.io_loop.time() - self.start_time,
                                             ))
 
             if hasattr(self, "stream"):
+                # TODO: this may cause a StreamClosedError to be raised
+                # by the connection's Future.  Should we cancel the
+                # connection more gracefully?
                 self.stream.close()
             return True
         else:
             # If our callback has already been called, we are probably
             # catching an exception that is not caused by us but rather
             # some child of our callback. Rather than drop it on the floor,
-            # pass it along.
-            return False
+            # pass it along, unless it's just the stream being closed.
+            return isinstance(value, StreamClosedError)
 
     def _on_close(self):
         if self.final_callback is not None:
@@ -372,22 +386,11 @@ class _HTTPConnection(object):
                 message = str(self.stream.error)
             raise HTTPError(599, message)
 
-    def _handle_1xx(self, code):
-        self.stream.read_until_regex(b"\r?\n\r?\n", self._on_headers)
-
-    def _on_headers(self, data):
-        data = native_str(data.decode("latin1"))
-        first_line, _, header_data = data.partition("\n")
-        match = re.match("HTTP/1.[01] ([0-9]+) ([^\r]*)", first_line)
-        assert match
-        code = int(match.group(1))
-        self.headers = HTTPHeaders.parse(header_data)
-        if 100 <= code < 200:
-            self._handle_1xx(code)
-            return
-        else:
-            self.code = code
-            self.reason = match.group(2)
+    def headers_received(self, first_line, headers):
+        self.headers = headers
+        version, code, reason = httputil.parse_response_start_line(first_line)
+        self.code = code
+        self.reason = reason
 
         if "Content-Length" in self.headers:
             if "," in self.headers["Content-Length"]:
@@ -405,16 +408,11 @@ class _HTTPConnection(object):
 
         if self.request.header_callback is not None:
             # re-attach the newline we split on earlier
-            self.request.header_callback(first_line + _)
+            self.request.header_callback(first_line + '\r\n')
             for k, v in self.headers.get_all():
                 self.request.header_callback("%s: %s\r\n" % (k, v))
             self.request.header_callback('\r\n')
 
-        if self.request.method == "HEAD" or self.code == 304:
-            # HEAD requests and 304 responses never have content, even
-            # though they may have content-length headers
-            self._on_body(b"")
-            return
         if 100 <= self.code < 200 or self.code == 204:
             # These response codes never have bodies
             # http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
@@ -422,21 +420,25 @@ class _HTTPConnection(object):
                     content_length not in (None, 0)):
                 raise ValueError("Response with code %d should not have body" %
                                  self.code)
-            self._on_body(b"")
-            return
 
         if (self.request.use_gzip and
                 self.headers.get("Content-Encoding") == "gzip"):
             self._decompressor = GzipDecompressor()
-        if self.headers.get("Transfer-Encoding") == "chunked":
-            self.chunks = []
-            self.stream.read_until(b"\r\n", self._on_chunk_length)
-        elif content_length is not None:
-            self.stream.read_bytes(content_length, self._on_body)
-        else:
-            self.stream.read_until_close(self._on_body)
 
-    def _on_body(self, data):
+    def finish(self):
+        if self._decompressor is not None:
+            tail = self._decompressor.flush()
+            if tail:
+                # I believe the tail will always be empty (i.e.
+                # decompress will return all it can).  The purpose
+                # of the flush call is to detect errors such
+                # as truncated input.  But in case it ever returns
+                # anything, treat it as an extra chunk
+                if self.request.streaming_callback is not None:
+                    self.request.streaming_callback(tail)
+                else:
+                    self.chunks.append(tail)
+        data = b''.join(self.chunks)
         self._remove_timeout()
         original_request = getattr(self.request, "original_request",
                                    self.request)
@@ -472,19 +474,12 @@ class _HTTPConnection(object):
             self.client.fetch(new_request, final_callback)
             self._on_end_request()
             return
-        if self._decompressor:
-            data = (self._decompressor.decompress(data) +
-                    self._decompressor.flush())
         if self.request.streaming_callback:
-            if self.chunks is None:
-                # if chunks is not None, we already called streaming_callback
-                # in _on_chunk_data
-                self.request.streaming_callback(data)
             buffer = BytesIO()
         else:
             buffer = BytesIO(data)  # TODO: don't require one big string?
         response = HTTPResponse(original_request,
-                                self.code, reason=self.reason,
+                                self.code, reason=getattr(self, 'reason', None),
                                 headers=self.headers,
                                 request_time=self.io_loop.time() - self.start_time,
                                 buffer=buffer,
@@ -495,40 +490,13 @@ class _HTTPConnection(object):
     def _on_end_request(self):
         self.stream.close()
 
-    def _on_chunk_length(self, data):
-        # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
-        length = int(data.strip(), 16)
-        if length == 0:
-            if self._decompressor is not None:
-                tail = self._decompressor.flush()
-                if tail:
-                    # I believe the tail will always be empty (i.e.
-                    # decompress will return all it can).  The purpose
-                    # of the flush call is to detect errors such
-                    # as truncated input.  But in case it ever returns
-                    # anything, treat it as an extra chunk
-                    if self.request.streaming_callback is not None:
-                        self.request.streaming_callback(tail)
-                    else:
-                        self.chunks.append(tail)
-                # all the data has been decompressed, so we don't need to
-                # decompress again in _on_body
-                self._decompressor = None
-            self._on_body(b''.join(self.chunks))
-        else:
-            self.stream.read_bytes(length + 2,  # chunk ends with \r\n
-                                   self._on_chunk_data)
-
-    def _on_chunk_data(self, data):
-        assert data[-2:] == b"\r\n"
-        chunk = data[:-2]
+    def data_received(self, chunk):
         if self._decompressor:
             chunk = self._decompressor.decompress(chunk)
         if self.request.streaming_callback is not None:
             self.request.streaming_callback(chunk)
         else:
             self.chunks.append(chunk)
-        self.stream.read_until(b"\r\n", self._on_chunk_length)
 
 
 if __name__ == "__main__":
index 5ca2993501a01ce709d5b03ed274630e5a14d3ab..19049991ad72936cf2bcf0d7e690c6337c8f7264 100644 (file)
@@ -4,8 +4,9 @@
 from __future__ import absolute_import, division, print_function, with_statement
 from tornado import httpclient, simple_httpclient, netutil
 from tornado.escape import json_decode, utf8, _unicode, recursive_unicode, native_str
+from tornado.http1connection import HTTP1Connection
 from tornado.httpserver import HTTPServer
-from tornado.httputil import HTTPHeaders
+from tornado.httputil import HTTPHeaders, HTTPStreamDelegate
 from tornado.iostream import IOStream
 from tornado.log import gen_log
 from tornado.netutil import ssl_options_to_context, Resolver
@@ -163,18 +164,7 @@ class MultipartTestHandler(RequestHandler):
                      })
 
 
-class RawRequestHTTPConnection(simple_httpclient._HTTPConnection):
-    def set_request(self, request):
-        self.__next_request = request
-
-    def _on_connect(self):
-        self.stream.write(self.__next_request)
-        self.__next_request = None
-        self.stream.read_until(b"\r\n\r\n", self._on_headers)
-
 # This test is also called from wsgi_test
-
-
 class HTTPConnectionTest(AsyncHTTPTestCase):
     def get_handlers(self):
         return [("/multipart", MultipartTestHandler),
@@ -184,23 +174,25 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
         return Application(self.get_handlers())
 
     def raw_fetch(self, headers, body):
-        with closing(Resolver(io_loop=self.io_loop)) as resolver:
-            with closing(SimpleAsyncHTTPClient(self.io_loop,
-                                               resolver=resolver)) as client:
-                conn = RawRequestHTTPConnection(
-                    self.io_loop, client,
-                    httpclient._RequestProxy(
-                        httpclient.HTTPRequest(self.get_url("/")),
-                        dict(httpclient.HTTPRequest._DEFAULTS)),
-                    None, self.stop,
-                    1024 * 1024, resolver)
-                conn.set_request(
-                    b"\r\n".join(headers +
-                                 [utf8("Content-Length: %d\r\n" % len(body))]) +
-                    b"\r\n" + body)
-                response = self.wait()
-                response.rethrow()
-                return response
+        with closing(IOStream(socket.socket())) as stream:
+            stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
+            self.wait()
+            stream.write(
+                b"\r\n".join(headers +
+                             [utf8("Content-Length: %d\r\n" % len(body))]) +
+                b"\r\n" + body)
+            chunks = []
+            test = self
+            class Delegate(HTTPStreamDelegate):
+                def data_received(self, chunk):
+                    chunks.append(chunk)
+
+                def finish(self):
+                    test.stop()
+            conn = HTTP1Connection(stream, None)
+            conn.process_response(Delegate(), method='GET')
+            self.wait()
+            return b''.join(chunks)
 
     def test_multipart_form(self):
         # Encodings here are tricky:  Headers are latin1, bodies can be
@@ -221,7 +213,7 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
                 b"--1234567890--",
                 b"",
             ]))
-        data = json_decode(response.body)
+        data = json_decode(response)
         self.assertEqual(u("\u00e9"), data["header"])
         self.assertEqual(u("\u00e1"), data["argument"])
         self.assertEqual(u("\u00f3"), data["filename"])
index ac98aaae4846c4edadc280a698dc868b7a65d60b..c10e7777ccbce42da4c908ccf53a95939166a6ec 100644 (file)
@@ -10,6 +10,7 @@ import re
 import socket
 import sys
 
+from tornado import gen
 from tornado.httpclient import AsyncHTTPClient
 from tornado.httputil import HTTPHeaders
 from tornado.ioloop import IOLoop
@@ -94,6 +95,18 @@ class HostEchoHandler(RequestHandler):
         self.write(self.request.headers["Host"])
 
 
+class NoContentLengthHandler(RequestHandler):
+    @gen.coroutine
+    def get(self):
+        # Emulate the old HTTP/1.0 behavior of returning a body with no
+        # content-length.  Tornado handles content-length at the framework
+        # level so we have to go around it.
+        stream = self.request.connection.stream
+        yield stream.write(b"HTTP/1.0 200 OK\r\n\r\n"
+                           b"hello")
+        stream.close()
+
+
 class SimpleHTTPClientTestMixin(object):
     def get_app(self):
         # callable objects to finish pending /trigger requests
@@ -112,6 +125,7 @@ class SimpleHTTPClientTestMixin(object):
             url("/see_other_post", SeeOtherPostHandler),
             url("/see_other_get", SeeOtherGetHandler),
             url("/host_echo", HostEchoHandler),
+            url("/no_content_length", NoContentLengthHandler),
         ], gzip=True)
 
     def test_singleton(self):
@@ -313,6 +327,10 @@ class SimpleHTTPClientTestMixin(object):
             self.triggers.popleft()()
             self.wait()
 
+    def test_no_content_length(self):
+        response = self.fetch("/no_content_length")
+        self.assertEquals(b"hello", response.body)
+
 
 class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
     def setUp(self):
index 1992b1869d8b8486187e1a9137b589c89aaea9d7..5e3bcd7d5bb2e0f22c77cdccbd53452115809a04 100644 (file)
@@ -860,8 +860,14 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
                 self.connect_future.set_exception(WebSocketError(
                     "Non-websocket response"))
 
-    def _handle_1xx(self, code):
-        assert code == 101
+    def headers_received(self, start_line, headers):
+        code = httputil.parse_response_start_line(start_line).code
+
+        if code != 101:
+            return super(WebSocketClientConnection, self).headers_received(
+                start_line, headers)
+
+        self.headers = headers
         assert self.headers['Upgrade'].lower() == 'websocket'
         assert self.headers['Connection'].lower() == 'upgrade'
         accept = WebSocketProtocol13.compute_accept_value(self.key)
@@ -874,7 +880,10 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
             self.io_loop.remove_timeout(self._timeout)
             self._timeout = None
 
+        self.stream.set_close_callback(self._on_close)
+
         self.connect_future.set_result(self)
+        return 'detach'
 
     def write_message(self, message, binary=False):
         """Sends a message to the WebSocket server."""