From: Ben Darnell Date: Sun, 12 Aug 2018 17:01:33 +0000 (-0400) Subject: tcpclient,tcpserver: Add type annotations X-Git-Tag: v6.0.0b1~33^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1368ee6a0935ec1ae2b186744d2c0e9a1e7b5838;p=thirdparty%2Ftornado.git tcpclient,tcpserver: Add type annotations --- diff --git a/setup.cfg b/setup.cfg index c843ce7b3..9a9b1dcc1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,6 +40,12 @@ disallow_untyped_defs = True [mypy-tornado.platform.*] disallow_untyped_defs = True +[mypy-tornado.tcpclient] +disallow_untyped_defs = True + +[mypy-tornado.tcpserver] +disallow_untyped_defs = True + [mypy-tornado.testing] disallow_untyped_defs = True @@ -81,5 +87,11 @@ check_untyped_defs = True [mypy-tornado.test.options_test] check_untyped_defs = True +[mypy-tornado.test.tcpclient_test] +check_untyped_defs = True + +[mypy-tornado.test.tcpserver_test] +check_untyped_defs = True + [mypy-tornado.test.testing_test] check_untyped_defs = True diff --git a/tornado/tcpclient.py b/tornado/tcpclient.py index 94dcdf150..902280ddf 100644 --- a/tornado/tcpclient.py +++ b/tornado/tcpclient.py @@ -20,6 +20,7 @@ import functools import socket import numbers import datetime +import ssl from tornado.concurrent import Future, future_add_done_callback from tornado.ioloop import IOLoop @@ -29,6 +30,11 @@ from tornado.netutil import Resolver from tornado.platform.auto import set_close_exec from tornado.gen import TimeoutError +import typing +from typing import Generator, Any, Union, Dict, Tuple, List, Callable, Iterator +if typing.TYPE_CHECKING: + from typing import Optional, Set # noqa: F401 + _INITIAL_CONNECT_TIMEOUT = 0.3 @@ -49,20 +55,23 @@ class _Connector(object): http://tools.ietf.org/html/rfc6555 """ - def __init__(self, addrinfo, connect): + def __init__(self, addrinfo: List[Tuple], + connect: Callable[[socket.AddressFamily, Tuple], + Tuple[IOStream, 'Future[IOStream]']]) -> None: self.io_loop = IOLoop.current() self.connect = connect - self.future = Future() - self.timeout = None - self.connect_timeout = None - self.last_error = None + self.future = Future() # type: Future[Tuple[socket.AddressFamily, Any, IOStream]] + self.timeout = None # type: Optional[object] + self.connect_timeout = None # type: Optional[object] + self.last_error = None # type: Optional[Exception] self.remaining = len(addrinfo) self.primary_addrs, self.secondary_addrs = self.split(addrinfo) - self.streams = set() + self.streams = set() # type: Set[IOStream] @staticmethod - def split(addrinfo): + def split(addrinfo: List[Tuple]) -> Tuple[List[Tuple[socket.AddressFamily, Tuple]], + List[Tuple[socket.AddressFamily, Tuple]]]: """Partition the ``addrinfo`` list by address family. Returns two lists. The first list contains the first entry from @@ -81,14 +90,17 @@ class _Connector(object): secondary.append((af, addr)) return primary, secondary - def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None): + def start( + self, timeout: float=_INITIAL_CONNECT_TIMEOUT, + connect_timeout: Union[float, datetime.timedelta]=None, + ) -> 'Future[Tuple[socket.AddressFamily, Any, IOStream]]': self.try_connect(iter(self.primary_addrs)) self.set_timeout(timeout) if connect_timeout is not None: self.set_connect_timeout(connect_timeout) return self.future - def try_connect(self, addrs): + def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None: try: af, addr = next(addrs) except StopIteration: @@ -104,7 +116,9 @@ class _Connector(object): future_add_done_callback( future, functools.partial(self.on_connect_done, addrs, af, addr)) - def on_connect_done(self, addrs, af, addr, future): + def on_connect_done(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]], + af: socket.AddressFamily, addr: Tuple, + future: 'Future[IOStream]') -> None: self.remaining -= 1 try: stream = future.result() @@ -130,35 +144,35 @@ class _Connector(object): self.future.set_result((af, addr, stream)) self.close_streams() - def set_timeout(self, timeout): + def set_timeout(self, timeout: float) -> None: self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout, self.on_timeout) - def on_timeout(self): + def on_timeout(self) -> None: self.timeout = None if not self.future.done(): self.try_connect(iter(self.secondary_addrs)) - def clear_timeout(self): + def clear_timeout(self) -> None: if self.timeout is not None: self.io_loop.remove_timeout(self.timeout) - def set_connect_timeout(self, connect_timeout): + def set_connect_timeout(self, connect_timeout: Union[float, datetime.timedelta]) -> None: self.connect_timeout = self.io_loop.add_timeout( connect_timeout, self.on_connect_timeout) - def on_connect_timeout(self): + def on_connect_timeout(self) -> None: if not self.future.done(): self.future.set_exception(TimeoutError()) self.close_streams() - def clear_timeouts(self): + def clear_timeouts(self) -> None: if self.timeout is not None: self.io_loop.remove_timeout(self.timeout) if self.connect_timeout is not None: self.io_loop.remove_timeout(self.connect_timeout) - def close_streams(self): + def close_streams(self) -> None: for stream in self.streams: stream.close() @@ -169,7 +183,7 @@ class TCPClient(object): .. versionchanged:: 5.0 The ``io_loop`` argument (deprecated since version 4.1) has been removed. """ - def __init__(self, resolver=None): + def __init__(self, resolver: Resolver=None) -> None: if resolver is not None: self.resolver = resolver self._own_resolver = False @@ -177,14 +191,15 @@ class TCPClient(object): self.resolver = Resolver() self._own_resolver = True - def close(self): + def close(self) -> None: if self._own_resolver: self.resolver.close() @gen.coroutine - def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None, - max_buffer_size=None, source_ip=None, source_port=None, - timeout=None): + def connect(self, host: str, port: int, af: socket.AddressFamily=socket.AF_UNSPEC, + ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, + max_buffer_size: int=None, source_ip: str=None, source_port: int=None, + timeout: Union[float, datetime.timedelta]=None) -> Generator[Any, Any, IOStream]: """Connect to the given host and port. Asynchronously returns an `.IOStream` (or `.SSLIOStream` if @@ -238,10 +253,11 @@ class TCPClient(object): else: stream = yield stream.start_tls(False, ssl_options=ssl_options, server_hostname=host) - raise gen.Return(stream) + return stream - def _create_stream(self, max_buffer_size, af, addr, source_ip=None, - source_port=None): + def _create_stream(self, max_buffer_size: int, af: socket.AddressFamily, + addr: Tuple, source_ip: str=None, + source_port: int=None) -> Tuple[IOStream, 'Future[IOStream]']: # Always connect in plaintext; we'll convert to ssl if necessary # after one connection has completed. source_port_bind = source_port if isinstance(source_port, int) else 0 @@ -267,8 +283,8 @@ class TCPClient(object): stream = IOStream(socket_obj, max_buffer_size=max_buffer_size) except socket.error as e: - fu = Future() + fu = Future() # type: Future[IOStream] fu.set_exception(e) - return fu + return stream, fu else: return stream, stream.connect(addr) diff --git a/tornado/tcpserver.py b/tornado/tcpserver.py index 038d8223c..4d59f1ebe 100644 --- a/tornado/tcpserver.py +++ b/tornado/tcpserver.py @@ -28,6 +28,11 @@ from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket from tornado import process from tornado.util import errno_from_exception +import typing +from typing import Union, Dict, Any, Iterable, Optional, Awaitable +if typing.TYPE_CHECKING: + from typing import Callable, List # noqa: F401 + class TCPServer(object): r"""A non-blocking, single-threaded TCP server. @@ -98,12 +103,12 @@ class TCPServer(object): .. versionchanged:: 5.0 The ``io_loop`` argument has been removed. """ - def __init__(self, ssl_options=None, max_buffer_size=None, - read_chunk_size=None): + def __init__(self, ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None, + max_buffer_size: int=None, read_chunk_size: int=None) -> None: self.ssl_options = ssl_options - self._sockets = {} # fd -> socket object - self._handlers = {} # fd -> remove_handler callable - self._pending_sockets = [] + self._sockets = {} # type: Dict[int, socket.socket] + self._handlers = {} # type: Dict[int, Callable[[], None]] + self._pending_sockets = [] # type: List[socket.socket] self._started = False self._stopped = False self.max_buffer_size = max_buffer_size @@ -126,7 +131,7 @@ class TCPServer(object): raise ValueError('keyfile "%s" does not exist' % self.ssl_options['keyfile']) - def listen(self, port, address=""): + def listen(self, port: int, address: str="") -> None: """Starts accepting connections on the given port. This method may be called more than once to listen on multiple ports. @@ -137,7 +142,7 @@ class TCPServer(object): sockets = bind_sockets(port, address=address) self.add_sockets(sockets) - def add_sockets(self, sockets): + def add_sockets(self, sockets: Iterable[socket.socket]) -> None: """Makes this server start accepting connections on the given sockets. The ``sockets`` parameter is a list of socket objects such as @@ -151,12 +156,13 @@ class TCPServer(object): self._handlers[sock.fileno()] = add_accept_handler( sock, self._handle_connection) - def add_socket(self, socket): + def add_socket(self, socket: socket.socket) -> None: """Singular version of `add_sockets`. Takes a single socket object.""" self.add_sockets([socket]) - def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128, - reuse_port=False): + def bind(self, port: int, address: str=None, + family: socket.AddressFamily=socket.AF_UNSPEC, + backlog: int=128, reuse_port: bool=False) -> None: """Binds this server to the given port on the given address. To start the server, call `start`. If you want to run this server @@ -187,7 +193,7 @@ class TCPServer(object): else: self._pending_sockets.extend(sockets) - def start(self, num_processes=1): + def start(self, num_processes: Optional[int]=1) -> None: """Starts this server in the `.IOLoop`. By default, we run the server in this process and do not fork any @@ -215,7 +221,7 @@ class TCPServer(object): self._pending_sockets = [] self.add_sockets(sockets) - def stop(self): + def stop(self) -> None: """Stops listening for new connections. Requests currently in progress may still continue after the @@ -230,7 +236,7 @@ class TCPServer(object): self._handlers.pop(fd)() sock.close() - def handle_stream(self, stream, address): + def handle_stream(self, stream: IOStream, address: tuple) -> Optional[Awaitable[None]]: """Override to handle a new `.IOStream` from an incoming connection. This method may be a coroutine; if so any exceptions it raises @@ -247,7 +253,7 @@ class TCPServer(object): """ raise NotImplementedError() - def _handle_connection(self, connection, address): + def _handle_connection(self, connection: socket.socket, address: Any) -> None: if self.ssl_options is not None: assert ssl, "Python 2.6+ and OpenSSL required for SSL" try: @@ -277,9 +283,10 @@ class TCPServer(object): raise try: if self.ssl_options is not None: - stream = SSLIOStream(connection, - max_buffer_size=self.max_buffer_size, - read_chunk_size=self.read_chunk_size) + stream = SSLIOStream( + connection, + max_buffer_size=self.max_buffer_size, + read_chunk_size=self.read_chunk_size) # type: IOStream else: stream = IOStream(connection, max_buffer_size=self.max_buffer_size, diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index df4eee9df..63cf23cf5 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -26,6 +26,11 @@ from tornado.testing import AsyncTestCase, gen_test from tornado.test.util import skipIfNoIPv6, refusing_port, skipIfNonUnix from tornado.gen import TimeoutError +import typing +if typing.TYPE_CHECKING: + from tornado.iostream import IOStream # noqa: F401 + from typing import List, Dict, Tuple # noqa: F401 + # Fake address families for testing. Used in place of AF_INET # and AF_INET6 because some installations do not have AF_INET6. AF1, AF2 = 1, 2 @@ -34,9 +39,9 @@ AF1, AF2 = 1, 2 class TestTCPServer(TCPServer): def __init__(self, family): super(TestTCPServer, self).__init__() - self.streams = [] + self.streams = [] # type: List[IOStream] self.queue = Queue() - sockets = bind_sockets(None, 'localhost', family) + sockets = bind_sockets(0, 'localhost', family) self.add_sockets(sockets) self.port = sockets[0].getsockname()[1] @@ -197,8 +202,9 @@ class ConnectorTest(AsyncTestCase): def setUp(self): super(ConnectorTest, self).setUp() - self.connect_futures = {} - self.streams = {} + self.connect_futures = {} \ + # type: Dict[Tuple[int, Tuple], Future[ConnectorTest.FakeStream]] + self.streams = {} # type: Dict[Tuple, ConnectorTest.FakeStream] self.addrinfo = [(AF1, 'a'), (AF1, 'b'), (AF2, 'c'), (AF2, 'd')] @@ -212,7 +218,7 @@ class ConnectorTest(AsyncTestCase): def create_stream(self, af, addr): stream = ConnectorTest.FakeStream() self.streams[addr] = stream - future = Future() + future = Future() # type: Future[ConnectorTest.FakeStream] self.connect_futures[(af, addr)] = future return stream, future diff --git a/tornado/test/tcpserver_test.py b/tornado/test/tcpserver_test.py index 7408c780d..d33f66b21 100644 --- a/tornado/test/tcpserver_test.py +++ b/tornado/test/tcpserver_test.py @@ -77,7 +77,7 @@ class TCPServerTest(AsyncTestCase): class TestServer(TCPServer): @gen.coroutine def handle_stream(self, stream, address): - server.stop() + server.stop() # type: ignore yield stream.read_until_close() sock, port = bind_unused_port()