]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
netutil: Add type annotations
authorBen Darnell <ben@bendarnell.com>
Sat, 11 Aug 2018 23:35:25 +0000 (19:35 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 11 Aug 2018 23:35:25 +0000 (19:35 -0400)
setup.cfg
tornado/concurrent.py
tornado/gen.py
tornado/netutil.py
tornado/test/netutil_test.py
tornado/testing.py
tornado/util.py

index 9a33f099c84e9c3040aa80a3c7be35fe45abe501..433b6d0d3a9a42a7ea58a76547185ed0d201e281 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -25,6 +25,9 @@ disallow_untyped_defs = True
 [mypy-tornado.log]
 disallow_untyped_defs = True
 
+[mypy-tornado.netutil]
+disallow_untyped_defs = True
+
 [mypy-tornado.options]
 disallow_untyped_defs = True
 
@@ -54,6 +57,9 @@ check_untyped_defs = True
 [mypy-tornado.test.log_test]
 check_untyped_defs = True
 
+[mypy-tornado.test.netutil_test]
+check_untyped_defs = True
+
 [mypy-tornado.test.options_test]
 check_untyped_defs = True
 
index d003ce915b2b86b480642f076adba223028b3dfb..d0dad7a34278cba8cb2b2d26e7ecddbc451c4e2a 100644 (file)
@@ -34,7 +34,7 @@ import sys
 import types
 
 import typing
-from typing import Any, Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple, Union
 
 _T = typing.TypeVar('_T')
 
@@ -53,9 +53,9 @@ def is_future(x: Any) -> bool:
     return isinstance(x, FUTURES)
 
 
-class DummyExecutor(object):
-    def submit(self, fn: Callable[..., _T], *args: Any, **kwargs: Any) -> 'Future[_T]':
-        future = Future()  # type: Future
+class DummyExecutor(futures.Executor):
+    def submit(self, fn: Callable[..., _T], *args: Any, **kwargs: Any) -> 'futures.Future[_T]':
+        future = futures.Future()  # type: futures.Future[_T]
         try:
             future_set_result_unless_cancelled(future, fn(*args, **kwargs))
         except Exception:
@@ -165,7 +165,8 @@ def chain_future(a: 'Future[_T]', b: 'Future[_T]') -> None:
         IOLoop.current().add_future(a, copy)
 
 
-def future_set_result_unless_cancelled(future: 'Future[_T]', value: _T) -> None:
+def future_set_result_unless_cancelled(future: Union['futures.Future[_T]', 'Future[_T]'],
+                                       value: _T) -> None:
     """Set the given ``value`` as the `Future`'s result, if not cancelled.
 
     Avoids asyncio.InvalidStateError when calling set_result() on
@@ -177,7 +178,7 @@ def future_set_result_unless_cancelled(future: 'Future[_T]', value: _T) -> None:
         future.set_result(value)
 
 
-def future_set_exc_info(future: 'Future[_T]',
+def future_set_exc_info(future: Union['futures.Future[_T]', 'Future[_T]'],
                         exc_info: Tuple[Optional[type], Optional[BaseException],
                                         Optional[types.TracebackType]]) -> None:
     """Set the given ``exc_info`` as the `Future`'s exception.
@@ -197,8 +198,20 @@ def future_set_exc_info(future: 'Future[_T]',
         future.set_exception(exc_info[1])
 
 
+@typing.overload
+def future_add_done_callback(future: 'futures.Future[_T]',
+                             callback: Callable[['futures.Future[_T]'], None]) -> None:
+    pass
+
+
+@typing.overload  # noqa: F811
 def future_add_done_callback(future: 'Future[_T]',
                              callback: Callable[['Future[_T]'], None]) -> None:
+    pass
+
+
+def future_add_done_callback(future: Union['futures.Future[_T]', 'Future[_T]'],  # noqa: F811
+                             callback: Callable[..., None]) -> None:
     """Arrange to call ``callback`` when ``future`` is complete.
 
     ``callback`` is invoked with one argument, the ``future``.
index 71c7f71bde3af48efaa232a0f70b63abffda81bb..346e897d648ae0860d82b68733dffee6a63b058b 100644 (file)
@@ -150,7 +150,7 @@ def _create_future() -> Future:
     return future
 
 
-def coroutine(func: Callable[..., _T]) -> Callable[..., 'Future[_T]']:
+def coroutine(func: Callable[..., 'Generator[Any, Any, _T]']) -> Callable[..., 'Future[_T]']:
     """Decorator for asynchronous generators.
 
     Any generator that yields objects from this module must be wrapped
index 7bb587dce70f476a7934703c9c432934be0b2e77..e844facfeb58f02b46163a690f11833478743d20 100644 (file)
@@ -15,6 +15,7 @@
 
 """Miscellaneous network utility code."""
 
+import concurrent.futures
 import errno
 import os
 import sys
@@ -28,6 +29,13 @@ from tornado.ioloop import IOLoop
 from tornado.platform.auto import set_close_exec
 from tornado.util import Configurable, errno_from_exception
 
+import typing
+from typing import List, Callable, Any, Type, Generator, Dict, Union, Tuple
+
+if typing.TYPE_CHECKING:
+    from asyncio import Future  # noqa: F401
+    from typing import Awaitable  # noqa: F401
+
 # Note that the naming of ssl.Purpose is confusing; the purpose
 # of a context is to authentiate the opposite side of the connection.
 _client_ssl_defaults = ssl.create_default_context(
@@ -61,8 +69,10 @@ if hasattr(errno, "WSAEWOULDBLOCK"):
 _DEFAULT_BACKLOG = 128
 
 
-def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
-                 backlog=_DEFAULT_BACKLOG, flags=None, reuse_port=False):
+def bind_sockets(port: int, address: str=None,
+                 family: socket.AddressFamily=socket.AF_UNSPEC,
+                 backlog: int=_DEFAULT_BACKLOG, flags: int=None,
+                 reuse_port: bool=False) -> List[socket.socket]:
     """Creates listening sockets bound to the given port and address.
 
     Returns a list of socket objects (multiple sockets are returned if
@@ -102,7 +112,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
     if flags is None:
         flags = socket.AI_PASSIVE
     bound_port = None
-    unique_addresses = set()
+    unique_addresses = set()  # type: set
     for res in sorted(socket.getaddrinfo(address, port, family, socket.SOCK_STREAM,
                                          0, flags), key=lambda x: x[0]):
         if res in unique_addresses:
@@ -154,7 +164,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
         if requested_port == 0 and bound_port is not None:
             sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
 
-        sock.setblocking(0)
+        sock.setblocking(False)
         sock.bind(sockaddr)
         bound_port = sock.getsockname()[1]
         sock.listen(backlog)
@@ -163,7 +173,8 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
 
 
 if hasattr(socket, 'AF_UNIX'):
-    def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
+    def bind_unix_socket(file: str, mode: int=0o600,
+                         backlog: int=_DEFAULT_BACKLOG) -> socket.socket:
         """Creates a listening unix socket.
 
         If a socket with the given name already exists, it will be deleted.
@@ -181,7 +192,7 @@ if hasattr(socket, 'AF_UNIX'):
             if errno_from_exception(e) != errno.ENOPROTOOPT:
                 # Hurd doesn't support SO_REUSEADDR
                 raise
-        sock.setblocking(0)
+        sock.setblocking(False)
         try:
             st = os.stat(file)
         except OSError as err:
@@ -198,7 +209,8 @@ if hasattr(socket, 'AF_UNIX'):
         return sock
 
 
-def add_accept_handler(sock, callback):
+def add_accept_handler(sock: socket.socket,
+                       callback: Callable[[socket.socket, Any], None]) -> Callable[[], None]:
     """Adds an `.IOLoop` event handler to accept new connections on ``sock``.
 
     When a connection is accepted, ``callback(connection, address)`` will
@@ -219,7 +231,7 @@ def add_accept_handler(sock, callback):
     io_loop = IOLoop.current()
     removed = [False]
 
-    def accept_handler(fd, events):
+    def accept_handler(fd: int, events: int) -> None:
         # More connections may come in while we're handling callbacks;
         # to prevent starvation of other tasks we must limit the number
         # of connections we accept at a time.  Ideally we would accept
@@ -251,7 +263,7 @@ def add_accept_handler(sock, callback):
             set_close_exec(connection.fileno())
             callback(connection, address)
 
-    def remove_handler():
+    def remove_handler() -> None:
         io_loop.remove_handler(sock)
         removed[0] = True
 
@@ -259,7 +271,7 @@ def add_accept_handler(sock, callback):
     return remove_handler
 
 
-def is_valid_ip(ip):
+def is_valid_ip(ip: str) -> bool:
     """Returns true if the given string is a well-formed IP address.
 
     Supports IPv4 and IPv6.
@@ -304,14 +316,16 @@ class Resolver(Configurable):
        `DefaultExecutorResolver`.
     """
     @classmethod
-    def configurable_base(cls):
+    def configurable_base(cls) -> Type['Resolver']:
         return Resolver
 
     @classmethod
-    def configurable_default(cls):
+    def configurable_default(cls) -> Type['Resolver']:
         return DefaultExecutorResolver
 
-    def resolve(self, host, port, family=socket.AF_UNSPEC):
+    def resolve(
+            self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+    ) -> 'Future[List[Tuple[int, Any]]]':
         """Resolves an address.
 
         The ``host`` argument is a string which may be a hostname or a
@@ -335,7 +349,7 @@ class Resolver(Configurable):
         """
         raise NotImplementedError()
 
-    def close(self):
+    def close(self) -> None:
         """Closes the `Resolver`, freeing any resources used.
 
         .. versionadded:: 3.1
@@ -344,7 +358,9 @@ class Resolver(Configurable):
         pass
 
 
-def _resolve_addr(host, port, family=socket.AF_UNSPEC):
+def _resolve_addr(
+        host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+) -> List[Tuple[int, Any]]:
     # On Solaris, getaddrinfo fails if the given port is not found
     # in /etc/services and no socket type is given, so we must pass
     # one here.  The socket type used here doesn't seem to actually
@@ -352,8 +368,8 @@ def _resolve_addr(host, port, family=socket.AF_UNSPEC):
     # so the addresses we return should still be usable with SOCK_DGRAM.
     addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
     results = []
-    for family, socktype, proto, canonname, address in addrinfo:
-        results.append((family, address))
+    for fam, socktype, proto, canonname, address in addrinfo:
+        results.append((fam, address))
     return results
 
 
@@ -363,10 +379,12 @@ class DefaultExecutorResolver(Resolver):
     .. versionadded:: 5.0
     """
     @gen.coroutine
-    def resolve(self, host, port, family=socket.AF_UNSPEC):
+    def resolve(
+            self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+    ) -> Generator[Any, Any, List[Tuple[int, Any]]]:
         result = yield IOLoop.current().run_in_executor(
             None, _resolve_addr, host, port, family)
-        raise gen.Return(result)
+        return result
 
 
 class ExecutorResolver(Resolver):
@@ -386,7 +404,8 @@ class ExecutorResolver(Resolver):
        The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
        of this class.
     """
-    def initialize(self, executor=None, close_executor=True):
+    def initialize(self, executor: concurrent.futures.Executor=None,
+                   close_executor: bool=True) -> None:
         self.io_loop = IOLoop.current()
         if executor is not None:
             self.executor = executor
@@ -395,13 +414,15 @@ class ExecutorResolver(Resolver):
             self.executor = dummy_executor
             self.close_executor = False
 
-    def close(self):
+    def close(self) -> None:
         if self.close_executor:
             self.executor.shutdown()
-        self.executor = None
+        self.executor = None  # type: ignore
 
     @run_on_executor
-    def resolve(self, host, port, family=socket.AF_UNSPEC):
+    def resolve(
+            self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+    ) -> List[Tuple[int, Any]]:
         return _resolve_addr(host, port, family)
 
 
@@ -415,7 +436,7 @@ class BlockingResolver(ExecutorResolver):
        The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
        of this class.
     """
-    def initialize(self):
+    def initialize(self) -> None:  # type: ignore
         super(BlockingResolver, self).initialize()
 
 
@@ -442,21 +463,20 @@ class ThreadedResolver(ExecutorResolver):
     _threadpool = None  # type: ignore
     _threadpool_pid = None  # type: int
 
-    def initialize(self, num_threads=10):
+    def initialize(self, num_threads: int=10) -> None:  # type: ignore
         threadpool = ThreadedResolver._create_threadpool(num_threads)
         super(ThreadedResolver, self).initialize(
             executor=threadpool, close_executor=False)
 
     @classmethod
-    def _create_threadpool(cls, num_threads):
+    def _create_threadpool(cls, num_threads: int) -> concurrent.futures.ThreadPoolExecutor:
         pid = os.getpid()
         if cls._threadpool_pid != pid:
             # Threads cannot survive after a fork, so if our pid isn't what it
             # was when we created the pool then delete it.
             cls._threadpool = None
         if cls._threadpool is None:
-            from concurrent.futures import ThreadPoolExecutor
-            cls._threadpool = ThreadPoolExecutor(num_threads)
+            cls._threadpool = concurrent.futures.ThreadPoolExecutor(num_threads)
             cls._threadpool_pid = pid
         return cls._threadpool
 
@@ -483,21 +503,23 @@ class OverrideResolver(Resolver):
     .. versionchanged:: 5.0
        Added support for host-port-family triplets.
     """
-    def initialize(self, resolver, mapping):
+    def initialize(self, resolver: Resolver, mapping: dict) -> None:  # type: ignore
         self.resolver = resolver
         self.mapping = mapping
 
-    def close(self):
+    def close(self) -> None:
         self.resolver.close()
 
-    def resolve(self, host, port, family=socket.AF_UNSPEC, *args, **kwargs):
+    def resolve(
+            self, host: str, port: int, family: socket.AddressFamily=socket.AF_UNSPEC,
+    ) -> 'Future[List[Tuple[int, Any]]]':
         if (host, port, family) in self.mapping:
             host, port = self.mapping[(host, port, family)]
         elif (host, port) in self.mapping:
             host, port = self.mapping[(host, port)]
         elif host in self.mapping:
             host = self.mapping[host]
-        return self.resolver.resolve(host, port, family, *args, **kwargs)
+        return self.resolver.resolve(host, port, family)
 
 
 # These are the keyword arguments to ssl.wrap_socket that must be translated
@@ -507,7 +529,7 @@ _SSL_CONTEXT_KEYWORDS = frozenset(['ssl_version', 'certfile', 'keyfile',
                                    'cert_reqs', 'ca_certs', 'ciphers'])
 
 
-def ssl_options_to_context(ssl_options):
+def ssl_options_to_context(ssl_options: Union[Dict[str, Any], ssl.SSLContext]) -> ssl.SSLContext:
     """Try to convert an ``ssl_options`` dictionary to an
     `~ssl.SSLContext` object.
 
@@ -543,7 +565,8 @@ def ssl_options_to_context(ssl_options):
     return context
 
 
-def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
+def ssl_wrap_socket(socket: socket.socket, ssl_options: Union[Dict[str, Any], ssl.SSLContext],
+                    server_hostname: str=None, **kwargs: Any) -> ssl.SSLSocket:
     """Returns an ``ssl.SSLSocket`` wrapping the given socket.
 
     ``ssl_options`` may be either an `ssl.SSLContext` object or a
index 6686e9fbaaa5a17a5eece9f94813494b2d642e46..c17356095f37e14746f88c1b2b73c0fa67b147b6 100644 (file)
@@ -13,6 +13,10 @@ from tornado.netutil import (
 from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
 from tornado.test.util import skipIfNoNetwork
 
+import typing
+if typing.TYPE_CHECKING:
+    from typing import List  # noqa: F401
+
 try:
     import pycares  # type: ignore
 except ImportError:
@@ -193,7 +197,7 @@ class TestPortAllocation(unittest.TestCase):
     def test_same_port_allocation(self):
         if 'TRAVIS' in os.environ:
             self.skipTest("dual-stack servers often have port conflicts on travis")
-        sockets = bind_sockets(None, 'localhost')
+        sockets = bind_sockets(0, 'localhost')
         try:
             port = sockets[0].getsockname()[1]
             self.assertTrue(all(s.getsockname()[1] == port
@@ -204,7 +208,7 @@ class TestPortAllocation(unittest.TestCase):
 
     @unittest.skipIf(not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported")
     def test_reuse_port(self):
-        sockets = []
+        sockets = []  # type: List[socket.socket]
         socket, port = bind_unused_port(reuse_port=True)
         try:
             sockets = bind_sockets(port, '127.0.0.1', reuse_port=True)
index 40a6e759373a4522b53046da17869ba7e18ccf8b..8c6e79ffeb0c7cb263b281cc763dca450e59ecbd 100644 (file)
@@ -52,7 +52,7 @@ def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]:
        Always binds to ``127.0.0.1`` without resolving the name
        ``localhost``.
     """
-    sock = netutil.bind_sockets(None, '127.0.0.1', family=socket.AF_INET,
+    sock = netutil.bind_sockets(0, '127.0.0.1', family=socket.AF_INET,
                                 reuse_port=reuse_port)[0]
     port = sock.getsockname()[1]
     return sock, port
index c10133116b1c0e81b1e9cc736044f47f151641c0..fad8d47e7c11879c2c35983629bb60200e1d8496 100644 (file)
@@ -296,6 +296,10 @@ class Configurable(object):
 
         Configurable classes should use `initialize` instead of ``__init__``.
 
+        When used with ``mypy``, subclasses will often need ``# type: ignore``
+        annotations on this method because ``mypy`` does not recognize that
+        its arguments may change in subclasses (as it does for ``__init__``).
+
         .. versionchanged:: 4.2
            Now accepts positional arguments in addition to keyword arguments.
         """