]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
*: Convert most non-test coroutines to native 2520/head
authorBen Darnell <ben@bendarnell.com>
Sun, 14 Oct 2018 19:28:49 +0000 (15:28 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 21 Oct 2018 01:46:38 +0000 (21:46 -0400)
The websocket module is an exception; it needs a little more work to
change things without breaking tests.

tornado/httpserver.py
tornado/netutil.py
tornado/simple_httpclient.py
tornado/tcpclient.py
tornado/test/simple_httpclient_test.py
tornado/web.py
tornado/websocket.py

index 78945c10a32649b3e976060a408ff4a4d5a54671..9376c568594ebfbc6763043c23e372342dd14a92 100644 (file)
@@ -30,7 +30,6 @@ import ssl
 
 from tornado.escape import native_str
 from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
-from tornado import gen
 from tornado import httputil
 from tornado import iostream
 from tornado import netutil
@@ -38,18 +37,7 @@ from tornado.tcpserver import TCPServer
 from tornado.util import Configurable
 
 import typing
-from typing import (
-    Union,
-    Any,
-    Dict,
-    Callable,
-    List,
-    Type,
-    Generator,
-    Tuple,
-    Optional,
-    Awaitable,
-)
+from typing import Union, Any, Dict, Callable, List, Type, Tuple, Optional, Awaitable
 
 if typing.TYPE_CHECKING:
     from typing import Set  # noqa: F401
@@ -209,12 +197,11 @@ class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate)
     def configurable_default(cls) -> Type[Configurable]:
         return HTTPServer
 
-    @gen.coroutine
-    def close_all_connections(self) -> Generator[Any, Any, None]:
+    async def close_all_connections(self) -> None:
         while self._connections:
             # Peek at an arbitrary element of the set
             conn = next(iter(self._connections))
-            yield conn.close()
+            await conn.close()
 
     def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None:
         context = _HTTPRequestContext(
index 7c07c48ff61e89c3d3ad199c9992c5f4f1263c71..d570ff96d03a07c065e852b76b861db3dd1b2c74 100644 (file)
@@ -24,17 +24,15 @@ import ssl
 import stat
 
 from tornado.concurrent import dummy_executor, run_on_executor
-from tornado import gen
 from tornado.ioloop import IOLoop
 from tornado.platform.auto import set_close_exec
 from tornado.util import Configurable, errno_from_exception
 
 import typing
-from typing import List, Callable, Any, Type, Generator, Dict, Union, Tuple
+from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable
 
 if typing.TYPE_CHECKING:
     from asyncio import Future  # noqa: F401
-    from typing import Awaitable  # noqa: F401
 
 # Note that the naming of ssl.Purpose is confusing; the purpose
 # of a context is to authentiate the opposite side of the connection.
@@ -337,7 +335,7 @@ class Resolver(Configurable):
 
     def resolve(
         self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
-    ) -> "Future[List[Tuple[int, Any]]]":
+    ) -> Awaitable[List[Tuple[int, Any]]]:
         """Resolves an address.
 
         The ``host`` argument is a string which may be a hostname or a
@@ -391,11 +389,10 @@ class DefaultExecutorResolver(Resolver):
     .. versionadded:: 5.0
     """
 
-    @gen.coroutine
-    def resolve(
+    async def resolve(
         self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
-    ) -> Generator[Any, Any, List[Tuple[int, Any]]]:
-        result = yield IOLoop.current().run_in_executor(
+    ) -> List[Tuple[int, Any]]:
+        result = await IOLoop.current().run_in_executor(
             None, _resolve_addr, host, port, family
         )
         return result
@@ -534,7 +531,7 @@ class OverrideResolver(Resolver):
 
     def resolve(
         self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
-    ) -> "Future[List[Tuple[int, Any]]]":
+    ) -> Awaitable[List[Tuple[int, Any]]]:
         if (host, port, family) in self.mapping:
             host, port = self.mapping[(host, port, family)]
         elif (host, port) in self.mapping:
index e8474cb9d674560350da4aac63ed7add42628440..d6ac20f91237e13589efcbe0f448b98879fdc063 100644 (file)
@@ -28,7 +28,7 @@ import time
 from io import BytesIO
 import urllib.parse
 
-from typing import Dict, Any, Generator, Callable, Optional, Type, Union
+from typing import Dict, Any, Callable, Optional, Type, Union
 from types import TracebackType
 import typing
 
@@ -277,10 +277,11 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         # Timeout handle returned by IOLoop.add_timeout
         self._timeout = None  # type: object
         self._sockaddr = None
-        IOLoop.current().add_callback(self.run)
+        IOLoop.current().add_future(
+            gen.convert_yielded(self.run()), lambda f: f.result()
+        )
 
-    @gen.coroutine
-    def run(self) -> Generator[Any, Any, None]:
+    async def run(self) -> None:
         try:
             self.parsed = urllib.parse.urlsplit(_unicode(self.request.url))
             if self.parsed.scheme not in ("http", "https"):
@@ -311,7 +312,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                     self.start_time + timeout,
                     functools.partial(self._on_timeout, "while connecting"),
                 )
-                stream = yield self.tcp_client.connect(
+                stream = await self.tcp_client.connect(
                     host,
                     port,
                     af=af,
@@ -417,9 +418,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
                 )
                 self.connection.write_headers(start_line, self.request.headers)
                 if self.request.expect_100_continue:
-                    yield self.connection.read_response(self)
+                    await self.connection.read_response(self)
                 else:
-                    yield self._write_body(True)
+                    await self._write_body(True)
         except Exception:
             if not self._handle_exception(*sys.exc_info()):
                 raise
@@ -489,18 +490,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
         )
         return connection
 
-    @gen.coroutine
-    def _write_body(self, start_read: bool) -> Generator[Any, Any, None]:
+    async def _write_body(self, start_read: bool) -> None:
         if self.request.body is not None:
             self.connection.write(self.request.body)
         elif self.request.body_producer is not None:
             fut = self.request.body_producer(self.connection.write)
             if fut is not None:
-                yield fut
+                await fut
         self.connection.finish()
         if start_read:
             try:
-                yield self.connection.read_response(self)
+                await self.connection.read_response(self)
             except StreamClosedError:
                 if not self._handle_exception(*sys.exc_info()):
                     raise
@@ -564,14 +564,14 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
             except HTTPStreamClosedError:
                 self._handle_exception(*sys.exc_info())
 
-    def headers_received(
+    async def headers_received(
         self,
         first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine],
         headers: httputil.HTTPHeaders,
     ) -> None:
         assert isinstance(first_line, httputil.ResponseStartLine)
         if self.request.expect_100_continue and first_line.code == 100:
-            self._write_body(False)
+            await self._write_body(False)
             return
         self.code = first_line.code
         self.reason = first_line.reason
index 5c2560d2c7f7635a15ade096d075125303a30aa2..9fe69716fc05106100d68b5e5a29973550e0a618 100644 (file)
@@ -31,7 +31,7 @@ from tornado.platform.auto import set_close_exec
 from tornado.gen import TimeoutError
 
 import typing
-from typing import Generator, Any, Union, Dict, Tuple, List, Callable, Iterator
+from typing import Any, Union, Dict, Tuple, List, Callable, Iterator
 
 if typing.TYPE_CHECKING:
     from typing import Optional, Set  # noqa: F401
@@ -219,8 +219,7 @@ class TCPClient(object):
         if self._own_resolver:
             self.resolver.close()
 
-    @gen.coroutine
-    def connect(
+    async def connect(
         self,
         host: str,
         port: int,
@@ -230,7 +229,7 @@ class TCPClient(object):
         source_ip: str = None,
         source_port: int = None,
         timeout: Union[float, datetime.timedelta] = None,
-    ) -> Generator[Any, Any, IOStream]:
+    ) -> IOStream:
         """Connect to the given host and port.
 
         Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
@@ -264,11 +263,11 @@ class TCPClient(object):
             else:
                 raise TypeError("Unsupported timeout %r" % timeout)
         if timeout is not None:
-            addrinfo = yield gen.with_timeout(
+            addrinfo = await gen.with_timeout(
                 timeout, self.resolver.resolve(host, port, af)
             )
         else:
-            addrinfo = yield self.resolver.resolve(host, port, af)
+            addrinfo = await self.resolver.resolve(host, port, af)
         connector = _Connector(
             addrinfo,
             functools.partial(
@@ -278,20 +277,20 @@ class TCPClient(object):
                 source_port=source_port,
             ),
         )
-        af, addr, stream = yield connector.start(connect_timeout=timeout)
+        af, addr, stream = await connector.start(connect_timeout=timeout)
         # TODO: For better performance we could cache the (af, addr)
         # information here and re-use it on subsequent connections to
         # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
         if ssl_options is not None:
             if timeout is not None:
-                stream = yield gen.with_timeout(
+                stream = await gen.with_timeout(
                     timeout,
                     stream.start_tls(
                         False, ssl_options=ssl_options, server_hostname=host
                     ),
                 )
             else:
-                stream = yield stream.start_tls(
+                stream = await stream.start_tls(
                     False, ssl_options=ssl_options, server_hostname=host
                 )
         return stream
index 11539a2f7270868aa33350c48d88549db7a753ad..956f7af3b1b6de79a180d1fdf90968310f6978ed 100644 (file)
@@ -18,7 +18,6 @@ from tornado.ioloop import IOLoop
 from tornado.iostream import UnsatisfiableReadError
 from tornado.locks import Event
 from tornado.log import gen_log
-from tornado.concurrent import Future
 from tornado.netutil import Resolver, bind_sockets
 from tornado.simple_httpclient import (
     SimpleAsyncHTTPClient,
@@ -273,9 +272,14 @@ class SimpleHTTPClientTestMixin(object):
     def test_connect_timeout(self):
         timeout = 0.1
 
+        cleanup_event = Event()
+        test = self
+
         class TimeoutResolver(Resolver):
-            def resolve(self, *args, **kwargs):
-                return Future()  # never completes
+            async def resolve(self, *args, **kwargs):
+                await cleanup_event.wait()
+                # Return something valid so the test doesn't raise during shutdown.
+                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]
 
         with closing(self.create_client(resolver=TimeoutResolver())) as client:
             with self.assertRaises(HTTPTimeoutError):
@@ -286,6 +290,12 @@ class SimpleHTTPClientTestMixin(object):
                     raise_error=True,
                 )
 
+        # Let the hanging coroutine clean up after itself. We need to
+        # wait more than a single IOLoop iteration for the SSL case,
+        # which logs errors on unexpected EOF.
+        cleanup_event.set()
+        yield gen.sleep(0.2)
+
     @skipOnTravis
     def test_request_timeout(self):
         timeout = 0.1
@@ -694,11 +704,16 @@ class HostnameMappingTestCase(AsyncHTTPTestCase):
 
 class ResolveTimeoutTestCase(AsyncHTTPTestCase):
     def setUp(self):
+        self.cleanup_event = Event()
+        test = self
+
         # Dummy Resolver subclass that never finishes.
         class BadResolver(Resolver):
             @gen.coroutine
             def resolve(self, *args, **kwargs):
-                yield Event().wait()
+                yield test.cleanup_event.wait()
+                # Return something valid so the test doesn't raise during cleanup.
+                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]
 
         super(ResolveTimeoutTestCase, self).setUp()
         self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver())
@@ -710,6 +725,10 @@ class ResolveTimeoutTestCase(AsyncHTTPTestCase):
         with self.assertRaises(HTTPTimeoutError):
             self.fetch("/hello", connect_timeout=0.1, raise_error=True)
 
+        # Let the hanging coroutine clean up after itself
+        self.cleanup_event.set()
+        self.io_loop.run_sync(lambda: gen.sleep(0))
+
 
 class MaxHeaderSizeTest(AsyncHTTPTestCase):
     def get_app(self):
index c6dca17ffb6295e01ec9151493864f00714defc1..46dba1fdb867dcf8621e94d3577f1c169b72e35a 100644 (file)
@@ -2571,11 +2571,10 @@ class StaticFileHandler(RequestHandler):
         with cls._lock:
             cls._static_hashes = {}
 
-    def head(self, path: str) -> "Future[None]":
+    def head(self, path: str) -> Awaitable[None]:
         return self.get(path, include_body=False)
 
-    @gen.coroutine
-    def get(self, path: str, include_body: bool = True) -> Generator[Any, Any, None]:
+    async def get(self, path: str, include_body: bool = True) -> None:
         # Set up our path instance variables.
         self.path = self.parse_url_path(path)
         del path  # make sure we don't refer to path instead of self.path again
@@ -2644,7 +2643,7 @@ class StaticFileHandler(RequestHandler):
             for chunk in content:
                 try:
                     self.write(chunk)
-                    yield self.flush()
+                    await self.flush()
                 except iostream.StreamClosedError:
                     return
         else:
index d100b98059360ceac787d576e2e943ff854a00f6..33629bd4c2b35bd4b397fe1babe202504056d677 100644 (file)
@@ -1444,16 +1444,17 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection):
                     WebSocketError("Non-websocket response")
                 )
 
-    def headers_received(
+    async def headers_received(
         self,
         start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
         headers: httputil.HTTPHeaders,
     ) -> None:
         assert isinstance(start_line, httputil.ResponseStartLine)
         if start_line.code != 101:
-            return super(WebSocketClientConnection, self).headers_received(
+            await super(WebSocketClientConnection, self).headers_received(
                 start_line, headers
             )
+            return
 
         self.headers = headers
         self.protocol = self.get_websocket_protocol()