From: Ben Darnell Date: Sun, 14 Oct 2018 19:28:49 +0000 (-0400) Subject: *: Convert most non-test coroutines to native X-Git-Tag: v6.0.0b1~23^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bc74d7b08aa6cdc1b592638b07742d98b9a3019a;p=thirdparty%2Ftornado.git *: Convert most non-test coroutines to native The websocket module is an exception; it needs a little more work to change things without breaking tests. --- diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 78945c10a..9376c5685 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -30,7 +30,6 @@ import ssl from tornado.escape import native_str from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters -from tornado import gen from tornado import httputil from tornado import iostream from tornado import netutil @@ -38,18 +37,7 @@ from tornado.tcpserver import TCPServer from tornado.util import Configurable import typing -from typing import ( - Union, - Any, - Dict, - Callable, - List, - Type, - Generator, - Tuple, - Optional, - Awaitable, -) +from typing import Union, Any, Dict, Callable, List, Type, Tuple, Optional, Awaitable if typing.TYPE_CHECKING: from typing import Set # noqa: F401 @@ -209,12 +197,11 @@ class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate) def configurable_default(cls) -> Type[Configurable]: return HTTPServer - @gen.coroutine - def close_all_connections(self) -> Generator[Any, Any, None]: + async def close_all_connections(self) -> None: while self._connections: # Peek at an arbitrary element of the set conn = next(iter(self._connections)) - yield conn.close() + await conn.close() def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None: context = _HTTPRequestContext( diff --git a/tornado/netutil.py b/tornado/netutil.py index 7c07c48ff..d570ff96d 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -24,17 +24,15 @@ import ssl import stat from tornado.concurrent import dummy_executor, run_on_executor -from tornado import gen from tornado.ioloop import IOLoop from tornado.platform.auto import set_close_exec from tornado.util import Configurable, errno_from_exception import typing -from typing import List, Callable, Any, Type, Generator, Dict, Union, Tuple +from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable if typing.TYPE_CHECKING: from asyncio import Future # noqa: F401 - from typing import Awaitable # noqa: F401 # Note that the naming of ssl.Purpose is confusing; the purpose # of a context is to authentiate the opposite side of the connection. @@ -337,7 +335,7 @@ class Resolver(Configurable): def resolve( self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC - ) -> "Future[List[Tuple[int, Any]]]": + ) -> Awaitable[List[Tuple[int, Any]]]: """Resolves an address. The ``host`` argument is a string which may be a hostname or a @@ -391,11 +389,10 @@ class DefaultExecutorResolver(Resolver): .. versionadded:: 5.0 """ - @gen.coroutine - def resolve( + async def resolve( self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC - ) -> Generator[Any, Any, List[Tuple[int, Any]]]: - result = yield IOLoop.current().run_in_executor( + ) -> List[Tuple[int, Any]]: + result = await IOLoop.current().run_in_executor( None, _resolve_addr, host, port, family ) return result @@ -534,7 +531,7 @@ class OverrideResolver(Resolver): def resolve( self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC - ) -> "Future[List[Tuple[int, Any]]]": + ) -> Awaitable[List[Tuple[int, Any]]]: if (host, port, family) in self.mapping: host, port = self.mapping[(host, port, family)] elif (host, port) in self.mapping: diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index e8474cb9d..d6ac20f91 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -28,7 +28,7 @@ import time from io import BytesIO import urllib.parse -from typing import Dict, Any, Generator, Callable, Optional, Type, Union +from typing import Dict, Any, Callable, Optional, Type, Union from types import TracebackType import typing @@ -277,10 +277,11 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): # Timeout handle returned by IOLoop.add_timeout self._timeout = None # type: object self._sockaddr = None - IOLoop.current().add_callback(self.run) + IOLoop.current().add_future( + gen.convert_yielded(self.run()), lambda f: f.result() + ) - @gen.coroutine - def run(self) -> Generator[Any, Any, None]: + async def run(self) -> None: try: self.parsed = urllib.parse.urlsplit(_unicode(self.request.url)) if self.parsed.scheme not in ("http", "https"): @@ -311,7 +312,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): self.start_time + timeout, functools.partial(self._on_timeout, "while connecting"), ) - stream = yield self.tcp_client.connect( + stream = await self.tcp_client.connect( host, port, af=af, @@ -417,9 +418,9 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): ) self.connection.write_headers(start_line, self.request.headers) if self.request.expect_100_continue: - yield self.connection.read_response(self) + await self.connection.read_response(self) else: - yield self._write_body(True) + await self._write_body(True) except Exception: if not self._handle_exception(*sys.exc_info()): raise @@ -489,18 +490,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): ) return connection - @gen.coroutine - def _write_body(self, start_read: bool) -> Generator[Any, Any, None]: + async def _write_body(self, start_read: bool) -> None: if self.request.body is not None: self.connection.write(self.request.body) elif self.request.body_producer is not None: fut = self.request.body_producer(self.connection.write) if fut is not None: - yield fut + await fut self.connection.finish() if start_read: try: - yield self.connection.read_response(self) + await self.connection.read_response(self) except StreamClosedError: if not self._handle_exception(*sys.exc_info()): raise @@ -564,14 +564,14 @@ class _HTTPConnection(httputil.HTTPMessageDelegate): except HTTPStreamClosedError: self._handle_exception(*sys.exc_info()) - def headers_received( + async def headers_received( self, first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine], headers: httputil.HTTPHeaders, ) -> None: assert isinstance(first_line, httputil.ResponseStartLine) if self.request.expect_100_continue and first_line.code == 100: - self._write_body(False) + await self._write_body(False) return self.code = first_line.code self.reason = first_line.reason diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 5c2560d2c..9fe69716f 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -31,7 +31,7 @@ from tornado.platform.auto import set_close_exec from tornado.gen import TimeoutError import typing -from typing import Generator, Any, Union, Dict, Tuple, List, Callable, Iterator +from typing import Any, Union, Dict, Tuple, List, Callable, Iterator if typing.TYPE_CHECKING: from typing import Optional, Set # noqa: F401 @@ -219,8 +219,7 @@ class TCPClient(object): if self._own_resolver: self.resolver.close() - @gen.coroutine - def connect( + async def connect( self, host: str, port: int, @@ -230,7 +229,7 @@ class TCPClient(object): source_ip: str = None, source_port: int = None, timeout: Union[float, datetime.timedelta] = None, - ) -> Generator[Any, Any, IOStream]: + ) -> IOStream: """Connect to the given host and port. Asynchronously returns an `.IOStream` (or `.SSLIOStream` if @@ -264,11 +263,11 @@ class TCPClient(object): else: raise TypeError("Unsupported timeout %r" % timeout) if timeout is not None: - addrinfo = yield gen.with_timeout( + addrinfo = await gen.with_timeout( timeout, self.resolver.resolve(host, port, af) ) else: - addrinfo = yield self.resolver.resolve(host, port, af) + addrinfo = await self.resolver.resolve(host, port, af) connector = _Connector( addrinfo, functools.partial( @@ -278,20 +277,20 @@ class TCPClient(object): source_port=source_port, ), ) - af, addr, stream = yield connector.start(connect_timeout=timeout) + af, addr, stream = await connector.start(connect_timeout=timeout) # TODO: For better performance we could cache the (af, addr) # information here and re-use it on subsequent connections to # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2) if ssl_options is not None: if timeout is not None: - stream = yield gen.with_timeout( + stream = await gen.with_timeout( timeout, stream.start_tls( False, ssl_options=ssl_options, server_hostname=host ), ) else: - stream = yield stream.start_tls( + stream = await stream.start_tls( False, ssl_options=ssl_options, server_hostname=host ) return stream diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index 11539a2f7..956f7af3b 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -18,7 +18,6 @@ from tornado.ioloop import IOLoop from tornado.iostream import UnsatisfiableReadError from tornado.locks import Event from tornado.log import gen_log -from tornado.concurrent import Future from tornado.netutil import Resolver, bind_sockets from tornado.simple_httpclient import ( SimpleAsyncHTTPClient, @@ -273,9 +272,14 @@ class SimpleHTTPClientTestMixin(object): def test_connect_timeout(self): timeout = 0.1 + cleanup_event = Event() + test = self + class TimeoutResolver(Resolver): - def resolve(self, *args, **kwargs): - return Future() # never completes + async def resolve(self, *args, **kwargs): + await cleanup_event.wait() + # Return something valid so the test doesn't raise during shutdown. + return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))] with closing(self.create_client(resolver=TimeoutResolver())) as client: with self.assertRaises(HTTPTimeoutError): @@ -286,6 +290,12 @@ class SimpleHTTPClientTestMixin(object): raise_error=True, ) + # Let the hanging coroutine clean up after itself. We need to + # wait more than a single IOLoop iteration for the SSL case, + # which logs errors on unexpected EOF. + cleanup_event.set() + yield gen.sleep(0.2) + @skipOnTravis def test_request_timeout(self): timeout = 0.1 @@ -694,11 +704,16 @@ class HostnameMappingTestCase(AsyncHTTPTestCase): class ResolveTimeoutTestCase(AsyncHTTPTestCase): def setUp(self): + self.cleanup_event = Event() + test = self + # Dummy Resolver subclass that never finishes. class BadResolver(Resolver): @gen.coroutine def resolve(self, *args, **kwargs): - yield Event().wait() + yield test.cleanup_event.wait() + # Return something valid so the test doesn't raise during cleanup. + return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))] super(ResolveTimeoutTestCase, self).setUp() self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver()) @@ -710,6 +725,10 @@ class ResolveTimeoutTestCase(AsyncHTTPTestCase): with self.assertRaises(HTTPTimeoutError): self.fetch("/hello", connect_timeout=0.1, raise_error=True) + # Let the hanging coroutine clean up after itself + self.cleanup_event.set() + self.io_loop.run_sync(lambda: gen.sleep(0)) + class MaxHeaderSizeTest(AsyncHTTPTestCase): def get_app(self): diff --git a/tornado/web.py b/tornado/web.py index c6dca17ff..46dba1fdb 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -2571,11 +2571,10 @@ class StaticFileHandler(RequestHandler): with cls._lock: cls._static_hashes = {} - def head(self, path: str) -> "Future[None]": + def head(self, path: str) -> Awaitable[None]: return self.get(path, include_body=False) - @gen.coroutine - def get(self, path: str, include_body: bool = True) -> Generator[Any, Any, None]: + async def get(self, path: str, include_body: bool = True) -> None: # Set up our path instance variables. self.path = self.parse_url_path(path) del path # make sure we don't refer to path instead of self.path again @@ -2644,7 +2643,7 @@ class StaticFileHandler(RequestHandler): for chunk in content: try: self.write(chunk) - yield self.flush() + await self.flush() except iostream.StreamClosedError: return else: diff --git a/tornado/websocket.py b/tornado/websocket.py index d100b9805..33629bd4c 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -1444,16 +1444,17 @@ class WebSocketClientConnection(simple_httpclient._HTTPConnection): WebSocketError("Non-websocket response") ) - def headers_received( + async def headers_received( self, start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine], headers: httputil.HTTPHeaders, ) -> None: assert isinstance(start_line, httputil.ResponseStartLine) if start_line.code != 101: - return super(WebSocketClientConnection, self).headers_received( + await super(WebSocketClientConnection, self).headers_received( start_line, headers ) + return self.headers = headers self.protocol = self.get_websocket_protocol()