git.ipfire.org Git - thirdparty/tornado.git/commitdiff
http1connection: Convert to native coroutines
author Ben Darnell <ben@bendarnell.com>
Sun, 14 Oct 2018 15:33:05 +0000 (11:33 -0400)
committer Ben Darnell <ben@bendarnell.com>
Sun, 14 Oct 2018 15:40:50 +0000 (11:40 -0400)
tornado/http1connection.py
tornado/test/httpserver_test.py

index 15c3c59b07a70047603b26428b50385152dba897..41df9f2c19a6149f2cfa3467c94a7f778031fff9 100644 (file)
@@ -18,6 +18,7 @@
 .. versionadded:: 4.0
 """
 
+import asyncio
 import logging
 import re
 import types
@@ -35,17 +36,7 @@ from tornado.log import gen_log, app_log
 from tornado.util import GzipDecompressor
 
 
-from typing import (
-    cast,
-    Optional,
-    Type,
-    Awaitable,
-    Generator,
-    Any,
-    Callable,
-    Union,
-    Tuple,
-)
+from typing import cast, Optional, Type, Awaitable, Callable, Union, Tuple
 
 
 class _QuietException(Exception):
@@ -186,20 +177,17 @@ class HTTP1Connection(httputil.HTTPConnection):
             delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
         return self._read_message(delegate)
 
-    @gen.coroutine
-    def _read_message(
-        self, delegate: httputil.HTTPMessageDelegate
-    ) -> Generator[Any, Any, bool]:
+    async def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> bool:
         need_delegate_close = False
         try:
             header_future = self.stream.read_until_regex(
                 b"\r?\n\r?\n", max_bytes=self.params.max_header_size
             )
             if self.params.header_timeout is None:
-                header_data = yield header_future
+                header_data = await header_future
             else:
                 try:
-                    header_data = yield gen.with_timeout(
+                    header_data = await gen.with_timeout(
                         self.stream.io_loop.time() + self.params.header_timeout,
                         header_future,
                         quiet_exceptions=iostream.StreamClosedError,
@@ -228,7 +216,7 @@ class HTTP1Connection(httputil.HTTPConnection):
             with _ExceptionLoggingContext(app_log):
                 header_recv_future = delegate.headers_received(start_line, headers)
                 if header_recv_future is not None:
-                    yield header_recv_future
+                    await header_recv_future
             if self.stream is None:
                 # We've been detached.
                 need_delegate_close = False
@@ -256,7 +244,7 @@ class HTTP1Connection(httputil.HTTPConnection):
                         )
                     # TODO: client delegates will get headers_received twice
                     # in the case of a 100-continue.  Document or change?
-                    yield self._read_message(delegate)
+                    await self._read_message(delegate)
             else:
                 if headers.get("Expect") == "100-continue" and not self._write_finished:
                     self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
@@ -266,10 +254,10 @@ class HTTP1Connection(httputil.HTTPConnection):
                 )
                 if body_future is not None:
                     if self._body_timeout is None:
-                        yield body_future
+                        await body_future
                     else:
                         try:
-                            yield gen.with_timeout(
+                            await gen.with_timeout(
                                 self.stream.io_loop.time() + self._body_timeout,
                                 body_future,
                                 quiet_exceptions=iostream.StreamClosedError,
@@ -292,7 +280,7 @@ class HTTP1Connection(httputil.HTTPConnection):
                 and not self.stream.closed()
             ):
                 self.stream.set_close_callback(self._on_connection_close)
-                yield self._finish_future
+                await self._finish_future
             if self.is_client and self._disconnect_on_finish:
                 self.close()
             if self.stream is None:
@@ -300,7 +288,7 @@ class HTTP1Connection(httputil.HTTPConnection):
         except httputil.HTTPInputError as e:
             gen_log.info("Malformed HTTP message from %s: %s", self.context, e)
             if not self.is_client:
-                yield self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
+                await self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
             self.close()
             return False
         finally:
@@ -656,12 +644,11 @@ class HTTP1Connection(httputil.HTTPConnection):
             return self._read_body_until_close(delegate)
         return None
 
-    @gen.coroutine
-    def _read_fixed_body(
+    async def _read_fixed_body(
         self, content_length: int, delegate: httputil.HTTPMessageDelegate
-    ) -> Generator[Any, Any, None]:
+    ) -> None:
         while content_length > 0:
-            body = yield self.stream.read_bytes(
+            body = await self.stream.read_bytes(
                 min(self.params.chunk_size, content_length), partial=True
             )
             content_length -= len(body)
@@ -669,19 +656,16 @@ class HTTP1Connection(httputil.HTTPConnection):
                 with _ExceptionLoggingContext(app_log):
                     ret = delegate.data_received(body)
                     if ret is not None:
-                        yield ret
+                        await ret
 
-    @gen.coroutine
-    def _read_chunked_body(
-        self, delegate: httputil.HTTPMessageDelegate
-    ) -> Generator[Any, Any, None]:
+    async def _read_chunked_body(self, delegate: httputil.HTTPMessageDelegate) -> None:
         # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
         total_size = 0
         while True:
-            chunk_len = yield self.stream.read_until(b"\r\n", max_bytes=64)
-            chunk_len = int(chunk_len.strip(), 16)
+            chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64)
+            chunk_len = int(chunk_len_str.strip(), 16)
             if chunk_len == 0:
-                crlf = yield self.stream.read_bytes(2)
+                crlf = await self.stream.read_bytes(2)
                 if crlf != b"\r\n":
                     raise httputil.HTTPInputError(
                         "improperly terminated chunked request"
@@ -692,7 +676,7 @@ class HTTP1Connection(httputil.HTTPConnection):
                 raise httputil.HTTPInputError("chunked body too large")
             bytes_to_read = chunk_len
             while bytes_to_read:
-                chunk = yield self.stream.read_bytes(
+                chunk = await self.stream.read_bytes(
                     min(bytes_to_read, self.params.chunk_size), partial=True
                 )
                 bytes_to_read -= len(chunk)
@@ -700,19 +684,20 @@ class HTTP1Connection(httputil.HTTPConnection):
                     with _ExceptionLoggingContext(app_log):
                         ret = delegate.data_received(chunk)
                         if ret is not None:
-                            yield ret
+                            await ret
             # chunk ends with \r\n
-            crlf = yield self.stream.read_bytes(2)
+            crlf = await self.stream.read_bytes(2)
             assert crlf == b"\r\n"
 
-    @gen.coroutine
-    def _read_body_until_close(
+    async def _read_body_until_close(
         self, delegate: httputil.HTTPMessageDelegate
-    ) -> Generator[Any, Any, None]:
-        body = yield self.stream.read_until_close()
+    ) -> None:
+        body = await self.stream.read_until_close()
         if not self._write_finished or self.is_client:
             with _ExceptionLoggingContext(app_log):
-                delegate.data_received(body)
+                ret = delegate.data_received(body)
+                if ret is not None:
+                    await ret
 
 
 class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
@@ -738,8 +723,7 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
             del headers["Content-Encoding"]
         return self._delegate.headers_received(start_line, headers)
 
-    @gen.coroutine
-    def data_received(self, chunk: bytes) -> Generator[Any, Any, None]:
+    async def data_received(self, chunk: bytes) -> None:
         if self._decompressor:
             compressed_data = chunk
             while compressed_data:
@@ -749,23 +733,26 @@ class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
                 if decompressed:
                     ret = self._delegate.data_received(decompressed)
                     if ret is not None:
-                        yield ret
+                        await ret
                 compressed_data = self._decompressor.unconsumed_tail
         else:
             ret = self._delegate.data_received(chunk)
             if ret is not None:
-                yield ret
+                await ret
 
     def finish(self) -> None:
         if self._decompressor is not None:
             tail = self._decompressor.flush()
             if tail:
-                # I believe the tail will always be empty (i.e.
-                # decompress will return all it can).  The purpose
-                # of the flush call is to detect errors such
-                # as truncated input.  But in case it ever returns
-                # anything, treat it as an extra chunk
-                self._delegate.data_received(tail)
+                # The tail should always be empty: decompress returned
+                # all that it can in data_received and the only
+                # purpose of the flush call is to detect errors such
+                # as truncated input. If we did legitimately get a new
+                # chunk at this point we'd need to change the
+                # interface to make finish() a coroutine.
+                raise ValueError(
+                    "decompressor.flush returned data; possibly truncated input"
+                )
         return self._delegate.finish()
 
     def on_connection_close(self) -> None:
@@ -794,8 +781,7 @@ class HTTP1ServerConnection(object):
         self.context = context
         self._serving_future = None  # type: Optional[Future[None]]
 
-    @gen.coroutine
-    def close(self) -> Generator[Any, Any, None]:
+    async def close(self) -> None:
         """Closes the connection.
 
         Returns a `.Future` that resolves after the serving loop has exited.
@@ -803,8 +789,9 @@ class HTTP1ServerConnection(object):
         self.stream.close()
         # Block until the serving loop is done, but ignore any exceptions
         # (start_serving is already responsible for logging them).
+        assert self._serving_future is not None
         try:
-            yield self._serving_future
+            await self._serving_future
         except Exception:
             pass
 
@@ -814,21 +801,20 @@ class HTTP1ServerConnection(object):
         :arg delegate: a `.HTTPServerConnectionDelegate`
         """
         assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
-        fut = self._server_request_loop(delegate)
+        fut = gen.convert_yielded(self._server_request_loop(delegate))
         self._serving_future = fut
         # Register the future on the IOLoop so its errors get logged.
         self.stream.io_loop.add_future(fut, lambda f: f.result())
 
-    @gen.coroutine
-    def _server_request_loop(
+    async def _server_request_loop(
         self, delegate: httputil.HTTPServerConnectionDelegate
-    ) -> Generator[Any, Any, None]:
+    ) -> None:
         try:
             while True:
                 conn = HTTP1Connection(self.stream, False, self.params, self.context)
                 request_delegate = delegate.start_request(self, conn)
                 try:
-                    ret = yield conn.read_response(request_delegate)
+                    ret = await conn.read_response(request_delegate)
                 except (iostream.StreamClosedError, iostream.UnsatisfiableReadError):
                     return
                 except _QuietException:
@@ -841,6 +827,6 @@ class HTTP1ServerConnection(object):
                     return
                 if not ret:
                     return
-                yield gen.moment
+                await asyncio.sleep(0)
         finally:
             delegate.on_close(self)
index 8a2e03a7f0cffb9bf8276000955e6b1e36f61b1e..b46a0a396c4f34912919b0a33609a58e0f45570e 100644 (file)
@@ -1,5 +1,4 @@
 from tornado import gen, netutil
-from tornado.concurrent import Future
 from tornado.escape import (
     json_decode,
     json_encode,
@@ -50,8 +49,8 @@ if typing.TYPE_CHECKING:
     from typing import Dict, List  # noqa: F401
 
 
-def read_stream_body(stream, callback):
-    """Reads an HTTP response from `stream` and runs callback with its
+async def read_stream_body(stream):
+    """Reads an HTTP response from `stream` and returns a tuple of its
     start_line, headers and body."""
     chunks = []
 
@@ -65,10 +64,11 @@ def read_stream_body(stream, callback):
 
         def finish(self):
             conn.detach()  # type: ignore
-            callback((self.start_line, self.headers, b"".join(chunks)))
 
     conn = HTTP1Connection(stream, True)
-    conn.read_response(Delegate())
+    delegate = Delegate()
+    await conn.read_response(delegate)
+    return delegate.start_line, delegate.headers, b"".join(chunks)
 
 
 class HandlerBaseTestCase(AsyncHTTPTestCase):
@@ -257,8 +257,9 @@ class HTTPConnectionTest(AsyncHTTPTestCase):
                 + newline
                 + body
             )
-            read_stream_body(stream, self.stop)
-            start_line, headers, body = self.wait()
+            start_line, headers, body = self.io_loop.run_sync(
+                lambda: read_stream_body(stream)
+            )
             return body
 
     def test_multipart_form(self):
@@ -459,8 +460,9 @@ class HTTPServerRawTest(AsyncHTTPTestCase):
     def test_malformed_first_line_response(self):
         with ExpectLog(gen_log, ".*Malformed HTTP request line"):
             self.stream.write(b"asdf\r\n\r\n")
-            read_stream_body(self.stream, self.stop)
-            start_line, headers, response = self.wait()
+            start_line, headers, response = self.io_loop.run_sync(
+                lambda: read_stream_body(self.stream)
+            )
             self.assertEqual("HTTP/1.1", start_line.version)
             self.assertEqual(400, start_line.code)
             self.assertEqual("Bad Request", start_line.reason)
@@ -498,8 +500,9 @@ bar
                 b"\n", b"\r\n"
             )
         )
-        read_stream_body(self.stream, self.stop)
-        start_line, headers, response = self.wait()
+        start_line, headers, response = self.io_loop.run_sync(
+            lambda: read_stream_body(self.stream)
+        )
         self.assertEqual(json_decode(response), {u"foo": [u"bar"]})
 
     def test_chunked_request_uppercase(self):
@@ -521,8 +524,9 @@ bar
                 b"\n", b"\r\n"
             )
         )
-        read_stream_body(self.stream, self.stop)
-        start_line, headers, response = self.wait()
+        start_line, headers, response = self.io_loop.run_sync(
+            lambda: read_stream_body(self.stream)
+        )
         self.assertEqual(json_decode(response), {u"foo": [u"bar"]})
 
     @gen_test
@@ -1239,9 +1243,7 @@ class BodyLimitsTest(AsyncHTTPTestCase):
                 b"Content-Length: 10240\r\n\r\n"
             )
             stream.write(b"a" * 10240)
-            fut = Future()  # type: Future[bytes]
-            read_stream_body(stream, callback=fut.set_result)
-            start_line, headers, response = yield fut
+            start_line, headers, response = yield read_stream_body(stream)
             self.assertEqual(response, b"10240")
             # Without the ?expected_size parameter, we get the old default value
             stream.write(