]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
tcpclient,tcpserver: Add type annotations
authorBen Darnell <ben@bendarnell.com>
Sun, 12 Aug 2018 17:01:33 +0000 (13:01 -0400)
committerBen Darnell <ben@bendarnell.com>
Mon, 10 Sep 2018 04:23:16 +0000 (00:23 -0400)
setup.cfg
tornado/tcpclient.py
tornado/tcpserver.py
tornado/test/tcpclient_test.py
tornado/test/tcpserver_test.py

index c843ce7b3f16b06749b6b60fb0537d292b1b93f2..9a9b1dcc14ca88ab8ba10e20bbbb08f4a56cfa35 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -40,6 +40,12 @@ disallow_untyped_defs = True
 [mypy-tornado.platform.*]
 disallow_untyped_defs = True
 
+[mypy-tornado.tcpclient]
+disallow_untyped_defs = True
+
+[mypy-tornado.tcpserver]
+disallow_untyped_defs = True
+
 [mypy-tornado.testing]
 disallow_untyped_defs = True
 
@@ -81,5 +87,11 @@ check_untyped_defs = True
 [mypy-tornado.test.options_test]
 check_untyped_defs = True
 
+[mypy-tornado.test.tcpclient_test]
+check_untyped_defs = True
+
+[mypy-tornado.test.tcpserver_test]
+check_untyped_defs = True
+
 [mypy-tornado.test.testing_test]
 check_untyped_defs = True
index 94dcdf1506fd8ece979c71317633ceaa3ceb57f7..902280ddfb8b103af4207918381bf1a253e4faa1 100644 (file)
@@ -20,6 +20,7 @@ import functools
 import socket
 import numbers
 import datetime
+import ssl
 
 from tornado.concurrent import Future, future_add_done_callback
 from tornado.ioloop import IOLoop
@@ -29,6 +30,11 @@ from tornado.netutil import Resolver
 from tornado.platform.auto import set_close_exec
 from tornado.gen import TimeoutError
 
+import typing
+from typing import Generator, Any, Union, Dict, Tuple, List, Callable, Iterator
+if typing.TYPE_CHECKING:
+    from typing import Optional, Set  # noqa: F401
+
 _INITIAL_CONNECT_TIMEOUT = 0.3
 
 
@@ -49,20 +55,23 @@ class _Connector(object):
     http://tools.ietf.org/html/rfc6555
 
     """
-    def __init__(self, addrinfo, connect):
+    def __init__(self, addrinfo: List[Tuple],
+                 connect: Callable[[socket.AddressFamily, Tuple],
+                                   Tuple[IOStream, 'Future[IOStream]']]) -> None:
         self.io_loop = IOLoop.current()
         self.connect = connect
 
-        self.future = Future()
-        self.timeout = None
-        self.connect_timeout = None
-        self.last_error = None
+        self.future = Future()  # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
+        self.timeout = None  # type: Optional[object]
+        self.connect_timeout = None  # type: Optional[object]
+        self.last_error = None  # type: Optional[Exception]
         self.remaining = len(addrinfo)
         self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
-        self.streams = set()
+        self.streams = set()  # type: Set[IOStream]
 
     @staticmethod
-    def split(addrinfo):
+    def split(addrinfo: List[Tuple]) -> Tuple[List[Tuple[socket.AddressFamily, Tuple]],
+                                              List[Tuple[socket.AddressFamily, Tuple]]]:
         """Partition the ``addrinfo`` list by address family.
 
         Returns two lists.  The first list contains the first entry from
@@ -81,14 +90,17 @@ class _Connector(object):
                 secondary.append((af, addr))
         return primary, secondary
 
-    def start(self, timeout=_INITIAL_CONNECT_TIMEOUT, connect_timeout=None):
+    def start(
+            self, timeout: float=_INITIAL_CONNECT_TIMEOUT,
+            connect_timeout: Union[float, datetime.timedelta]=None,
+    ) -> 'Future[Tuple[socket.AddressFamily, Any, IOStream]]':
         self.try_connect(iter(self.primary_addrs))
         self.set_timeout(timeout)
         if connect_timeout is not None:
             self.set_connect_timeout(connect_timeout)
         return self.future
 
-    def try_connect(self, addrs):
+    def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None:
         try:
             af, addr = next(addrs)
         except StopIteration:
@@ -104,7 +116,9 @@ class _Connector(object):
         future_add_done_callback(
             future, functools.partial(self.on_connect_done, addrs, af, addr))
 
-    def on_connect_done(self, addrs, af, addr, future):
+    def on_connect_done(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
+                        af: socket.AddressFamily, addr: Tuple,
+                        future: 'Future[IOStream]') -> None:
         self.remaining -= 1
         try:
             stream = future.result()
@@ -130,35 +144,35 @@ class _Connector(object):
             self.future.set_result((af, addr, stream))
             self.close_streams()
 
-    def set_timeout(self, timeout):
+    def set_timeout(self, timeout: float) -> None:
         self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
                                                 self.on_timeout)
 
-    def on_timeout(self):
+    def on_timeout(self) -> None:
         self.timeout = None
         if not self.future.done():
             self.try_connect(iter(self.secondary_addrs))
 
-    def clear_timeout(self):
+    def clear_timeout(self) -> None:
         if self.timeout is not None:
             self.io_loop.remove_timeout(self.timeout)
 
-    def set_connect_timeout(self, connect_timeout):
+    def set_connect_timeout(self, connect_timeout: Union[float, datetime.timedelta]) -> None:
         self.connect_timeout = self.io_loop.add_timeout(
             connect_timeout, self.on_connect_timeout)
 
-    def on_connect_timeout(self):
+    def on_connect_timeout(self) -> None:
         if not self.future.done():
             self.future.set_exception(TimeoutError())
         self.close_streams()
 
-    def clear_timeouts(self):
+    def clear_timeouts(self) -> None:
         if self.timeout is not None:
             self.io_loop.remove_timeout(self.timeout)
         if self.connect_timeout is not None:
             self.io_loop.remove_timeout(self.connect_timeout)
 
-    def close_streams(self):
+    def close_streams(self) -> None:
         for stream in self.streams:
             stream.close()
 
@@ -169,7 +183,7 @@ class TCPClient(object):
     .. versionchanged:: 5.0
        The ``io_loop`` argument (deprecated since version 4.1) has been removed.
     """
-    def __init__(self, resolver=None):
+    def __init__(self, resolver: Resolver=None) -> None:
         if resolver is not None:
             self.resolver = resolver
             self._own_resolver = False
@@ -177,14 +191,15 @@ class TCPClient(object):
             self.resolver = Resolver()
             self._own_resolver = True
 
-    def close(self):
+    def close(self) -> None:
         if self._own_resolver:
             self.resolver.close()
 
     @gen.coroutine
-    def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
-                max_buffer_size=None, source_ip=None, source_port=None,
-                timeout=None):
+    def connect(self, host: str, port: int, af: socket.AddressFamily=socket.AF_UNSPEC,
+                ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
+                max_buffer_size: int=None, source_ip: str=None, source_port: int=None,
+                timeout: Union[float, datetime.timedelta]=None) -> Generator[Any, Any, IOStream]:
         """Connect to the given host and port.
 
         Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
@@ -238,10 +253,11 @@ class TCPClient(object):
             else:
                 stream = yield stream.start_tls(False, ssl_options=ssl_options,
                                                 server_hostname=host)
-        raise gen.Return(stream)
+        return stream
 
-    def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
-                       source_port=None):
+    def _create_stream(self, max_buffer_size: int, af: socket.AddressFamily,
+                       addr: Tuple, source_ip: str=None,
+                       source_port: int=None) -> Tuple[IOStream, 'Future[IOStream]']:
         # Always connect in plaintext; we'll convert to ssl if necessary
         # after one connection has completed.
         source_port_bind = source_port if isinstance(source_port, int) else 0
@@ -267,8 +283,8 @@ class TCPClient(object):
             stream = IOStream(socket_obj,
                               max_buffer_size=max_buffer_size)
         except socket.error as e:
-            fu = Future()
+            fu = Future()  # type: Future[IOStream]
             fu.set_exception(e)
-            return fu
+            return stream, fu
         else:
             return stream, stream.connect(addr)
index 038d8223cdbf6da79c5e74c7a33c2276c5657e9d..4d59f1ebef22136c054cd2b8f8bfa766b14e58fa 100644 (file)
@@ -28,6 +28,11 @@ from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
 from tornado import process
 from tornado.util import errno_from_exception
 
+import typing
+from typing import Union, Dict, Any, Iterable, Optional, Awaitable
+if typing.TYPE_CHECKING:
+    from typing import Callable, List  # noqa: F401
+
 
 class TCPServer(object):
     r"""A non-blocking, single-threaded TCP server.
@@ -98,12 +103,12 @@ class TCPServer(object):
     .. versionchanged:: 5.0
        The ``io_loop`` argument has been removed.
     """
-    def __init__(self, ssl_options=None, max_buffer_size=None,
-                 read_chunk_size=None):
+    def __init__(self, ssl_options: Union[Dict[str, Any], ssl.SSLContext]=None,
+                 max_buffer_size: int=None, read_chunk_size: int=None) -> None:
         self.ssl_options = ssl_options
-        self._sockets = {}   # fd -> socket object
-        self._handlers = {}  # fd -> remove_handler callable
-        self._pending_sockets = []
+        self._sockets = {}   # type: Dict[int, socket.socket]
+        self._handlers = {}  # type: Dict[int, Callable[[], None]]
+        self._pending_sockets = []  # type: List[socket.socket]
         self._started = False
         self._stopped = False
         self.max_buffer_size = max_buffer_size
@@ -126,7 +131,7 @@ class TCPServer(object):
                 raise ValueError('keyfile "%s" does not exist' %
                                  self.ssl_options['keyfile'])
 
-    def listen(self, port, address=""):
+    def listen(self, port: int, address: str="") -> None:
         """Starts accepting connections on the given port.
 
         This method may be called more than once to listen on multiple ports.
@@ -137,7 +142,7 @@ class TCPServer(object):
         sockets = bind_sockets(port, address=address)
         self.add_sockets(sockets)
 
-    def add_sockets(self, sockets):
+    def add_sockets(self, sockets: Iterable[socket.socket]) -> None:
         """Makes this server start accepting connections on the given sockets.
 
         The ``sockets`` parameter is a list of socket objects such as
@@ -151,12 +156,13 @@ class TCPServer(object):
             self._handlers[sock.fileno()] = add_accept_handler(
                 sock, self._handle_connection)
 
-    def add_socket(self, socket):
+    def add_socket(self, socket: socket.socket) -> None:
         """Singular version of `add_sockets`.  Takes a single socket object."""
         self.add_sockets([socket])
 
-    def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
-             reuse_port=False):
+    def bind(self, port: int, address: str=None,
+             family: socket.AddressFamily=socket.AF_UNSPEC,
+             backlog: int=128, reuse_port: bool=False) -> None:
         """Binds this server to the given port on the given address.
 
         To start the server, call `start`. If you want to run this server
@@ -187,7 +193,7 @@ class TCPServer(object):
         else:
             self._pending_sockets.extend(sockets)
 
-    def start(self, num_processes=1):
+    def start(self, num_processes: Optional[int]=1) -> None:
         """Starts this server in the `.IOLoop`.
 
         By default, we run the server in this process and do not fork any
@@ -215,7 +221,7 @@ class TCPServer(object):
         self._pending_sockets = []
         self.add_sockets(sockets)
 
-    def stop(self):
+    def stop(self) -> None:
         """Stops listening for new connections.
 
         Requests currently in progress may still continue after the
@@ -230,7 +236,7 @@ class TCPServer(object):
             self._handlers.pop(fd)()
             sock.close()
 
-    def handle_stream(self, stream, address):
+    def handle_stream(self, stream: IOStream, address: tuple) -> Optional[Awaitable[None]]:
         """Override to handle a new `.IOStream` from an incoming connection.
 
         This method may be a coroutine; if so any exceptions it raises
@@ -247,7 +253,7 @@ class TCPServer(object):
         """
         raise NotImplementedError()
 
-    def _handle_connection(self, connection, address):
+    def _handle_connection(self, connection: socket.socket, address: Any) -> None:
         if self.ssl_options is not None:
             assert ssl, "Python 2.6+ and OpenSSL required for SSL"
             try:
@@ -277,9 +283,10 @@ class TCPServer(object):
                     raise
         try:
             if self.ssl_options is not None:
-                stream = SSLIOStream(connection,
-                                     max_buffer_size=self.max_buffer_size,
-                                     read_chunk_size=self.read_chunk_size)
+                stream = SSLIOStream(
+                    connection,
+                    max_buffer_size=self.max_buffer_size,
+                    read_chunk_size=self.read_chunk_size)  # type: IOStream
             else:
                 stream = IOStream(connection,
                                   max_buffer_size=self.max_buffer_size,
index df4eee9df57303af5abf48e1b4c971986665ee3f..63cf23cf5b5856ece282036b414bf296048cc218 100644 (file)
@@ -26,6 +26,11 @@ from tornado.testing import AsyncTestCase, gen_test
 from tornado.test.util import skipIfNoIPv6, refusing_port, skipIfNonUnix
 from tornado.gen import TimeoutError
 
+import typing
+if typing.TYPE_CHECKING:
+    from tornado.iostream import IOStream  # noqa: F401
+    from typing import List, Dict, Tuple  # noqa: F401
+
 # Fake address families for testing.  Used in place of AF_INET
 # and AF_INET6 because some installations do not have AF_INET6.
 AF1, AF2 = 1, 2
@@ -34,9 +39,9 @@ AF1, AF2 = 1, 2
 class TestTCPServer(TCPServer):
     def __init__(self, family):
         super(TestTCPServer, self).__init__()
-        self.streams = []
+        self.streams = []  # type: List[IOStream]
         self.queue = Queue()
-        sockets = bind_sockets(None, 'localhost', family)
+        sockets = bind_sockets(0, 'localhost', family)
         self.add_sockets(sockets)
         self.port = sockets[0].getsockname()[1]
 
@@ -197,8 +202,9 @@ class ConnectorTest(AsyncTestCase):
 
     def setUp(self):
         super(ConnectorTest, self).setUp()
-        self.connect_futures = {}
-        self.streams = {}
+        self.connect_futures = {} \
+            # type: Dict[Tuple[int, Tuple], Future[ConnectorTest.FakeStream]]
+        self.streams = {}  # type: Dict[Tuple, ConnectorTest.FakeStream]
         self.addrinfo = [(AF1, 'a'), (AF1, 'b'),
                          (AF2, 'c'), (AF2, 'd')]
 
@@ -212,7 +218,7 @@ class ConnectorTest(AsyncTestCase):
     def create_stream(self, af, addr):
         stream = ConnectorTest.FakeStream()
         self.streams[addr] = stream
-        future = Future()
+        future = Future()  # type: Future[ConnectorTest.FakeStream]
         self.connect_futures[(af, addr)] = future
         return stream, future
 
index 7408c780de1a5f016d7eaee192cec9edb3b97edc..d33f66b21582c32a7d6d63844d726c8272d206b4 100644 (file)
@@ -77,7 +77,7 @@ class TCPServerTest(AsyncTestCase):
         class TestServer(TCPServer):
             @gen.coroutine
             def handle_stream(self, stream, address):
-                server.stop()
+                server.stop()  # type: ignore
                 yield stream.read_until_close()
 
         sock, port = bind_unused_port()