]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
testing: Type-annotate the module
authorBen Darnell <ben@bendarnell.com>
Sat, 11 Aug 2018 20:40:07 +0000 (16:40 -0400)
committerBen Darnell <ben@bendarnell.com>
Sat, 11 Aug 2018 20:40:07 +0000 (16:40 -0400)
setup.cfg
tornado/testing.py

index 26d19f2fe6cc3e23877235e13ad882d736ebb9cd..13a4ce3f8b2510967d182dc60d71c92d002ff130 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,6 +19,9 @@ disallow_untyped_defs = True
 [mypy-tornado.gen]
 disallow_untyped_defs = True
 
+[mypy-tornado.testing]
+disallow_untyped_defs = True
+
 # It's generally too tedious to require type annotations in tests, but
 # we do want to type check them as much as type inference allows.
 [mypy-tornado.test.util_test]
@@ -35,3 +38,6 @@ check_untyped_defs = True
 
 [mypy-tornado.test.gen_test]
 check_untyped_defs = True
+
+[mypy-tornado.test.testing_test]
+check_untyped_defs = True
index 4521ea9cc902f6fb31b18aa72f11e584217c39fc..40a6e759373a4522b53046da17869ba7e18ccf8b 100644 (file)
@@ -21,7 +21,7 @@ import sys
 import unittest
 
 from tornado import gen
-from tornado.httpclient import AsyncHTTPClient
+from tornado.httpclient import AsyncHTTPClient, HTTPResponse
 from tornado.httpserver import HTTPServer
 from tornado.ioloop import IOLoop, TimeoutError
 from tornado import netutil
@@ -29,12 +29,21 @@ from tornado.platform.asyncio import AsyncIOMainLoop
 from tornado.process import Subprocess
 from tornado.log import app_log
 from tornado.util import raise_exc_info, basestring_type
+from tornado.web import Application
+
+import typing
+from typing import Tuple, Any, Callable, Type, Dict, Union, Coroutine, Optional
+from types import TracebackType
+
+if typing.TYPE_CHECKING:
+    _ExcInfoTuple = Tuple[Optional[Type[BaseException]], Optional[BaseException],
+                          Optional[TracebackType]]
 
 
 _NON_OWNED_IOLOOPS = AsyncIOMainLoop
 
 
-def bind_unused_port(reuse_port=False):
+def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]:
     """Binds a server socket to an available port on localhost.
 
     Returns a tuple (socket, port).
@@ -49,17 +58,20 @@ def bind_unused_port(reuse_port=False):
     return sock, port
 
 
-def get_async_test_timeout():
+def get_async_test_timeout() -> float:
     """Get the global timeout setting for async tests.
 
     Returns a float, the timeout in seconds.
 
     .. versionadded:: 3.1
     """
-    try:
-        return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
-    except (ValueError, TypeError):
-        return 5
+    env = os.environ.get('ASYNC_TEST_TIMEOUT')
+    if env is not None:
+        try:
+            return float(env)
+        except ValueError:
+            pass
+    return 5
 
 
 class _TestMethodWrapper(object):
@@ -71,10 +83,10 @@ class _TestMethodWrapper(object):
     necessarily errors, but we alert anyway since there is no good
     reason to return a value from a test).
     """
-    def __init__(self, orig_method):
+    def __init__(self, orig_method: Callable) -> None:
         self.orig_method = orig_method
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args: Any, **kwargs: Any) -> None:
         result = self.orig_method(*args, **kwargs)
         if isinstance(result, Generator) or inspect.iscoroutine(result):
             raise TypeError("Generator and coroutine test methods should be"
@@ -83,7 +95,7 @@ class _TestMethodWrapper(object):
             raise ValueError("Return value from test method ignored: %r" %
                              result)
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Proxy all unknown attributes to the original method.
 
         This is important for some of the decorators in the `unittest`
@@ -138,12 +150,12 @@ class AsyncTestCase(unittest.TestCase):
                 # Test contents of response
                 self.assertIn("FriendFeed", response.body)
     """
-    def __init__(self, methodName='runTest'):
+    def __init__(self, methodName: str='runTest') -> None:
         super(AsyncTestCase, self).__init__(methodName)
         self.__stopped = False
         self.__running = False
-        self.__failure = None
-        self.__stop_args = None
+        self.__failure = None  # type: Optional[_ExcInfoTuple]
+        self.__stop_args = None  # type: Any
         self.__timeout = None
 
         # It's easy to forget the @gen_test decorator, but if you do
@@ -152,12 +164,15 @@ class AsyncTestCase(unittest.TestCase):
         # make sure it's not an undecorated generator.
         setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
 
-    def setUp(self):
+        # Not used in this class itself, but used by @gen_test
+        self._test_generator = None  # type: Optional[Union[Generator, Coroutine]]
+
+    def setUp(self) -> None:
         super(AsyncTestCase, self).setUp()
         self.io_loop = self.get_new_ioloop()
         self.io_loop.make_current()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         # Clean up Subprocess, so it can be used again with a new ioloop.
         Subprocess.uninitialize()
         self.io_loop.clear_current()
@@ -174,7 +189,7 @@ class AsyncTestCase(unittest.TestCase):
         # unittest machinery understands.
         self.__rethrow()
 
-    def get_new_ioloop(self):
+    def get_new_ioloop(self) -> IOLoop:
         """Returns the `.IOLoop` to use for this test.
 
         By default, a new `.IOLoop` is created for each test.
@@ -187,7 +202,7 @@ class AsyncTestCase(unittest.TestCase):
         """
         return IOLoop()
 
-    def _handle_exception(self, typ, value, tb):
+    def _handle_exception(self, typ: Type[Exception], value: Exception, tb: TracebackType) -> bool:
         if self.__failure is None:
             self.__failure = (typ, value, tb)
         else:
@@ -196,21 +211,22 @@ class AsyncTestCase(unittest.TestCase):
         self.stop()
         return True
 
-    def __rethrow(self):
+    def __rethrow(self) -> None:
         if self.__failure is not None:
             failure = self.__failure
             self.__failure = None
             raise_exc_info(failure)
 
-    def run(self, result=None):
-        super(AsyncTestCase, self).run(result)
+    def run(self, result: unittest.TestResult=None) -> unittest.TestCase:
+        ret = super(AsyncTestCase, self).run(result)
         # As a last resort, if an exception escaped super.run() and wasn't
         # re-raised in tearDown, raise it here.  This will cause the
         # unittest run to fail messily, but that's better than silently
         # ignoring an error.
         self.__rethrow()
+        return ret
 
-    def stop(self, _arg=None, **kwargs):
+    def stop(self, _arg: Any=None, **kwargs: Any) -> None:
         """Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
         to return.
 
@@ -228,7 +244,7 @@ class AsyncTestCase(unittest.TestCase):
             self.__running = False
         self.__stopped = True
 
-    def wait(self, condition=None, timeout=None):
+    def wait(self, condition: Callable[..., bool]=None, timeout: float=None) -> None:
         """Runs the `.IOLoop` until stop is called or timeout has passed.
 
         In the event of a timeout, an exception will be thrown. The
@@ -251,7 +267,7 @@ class AsyncTestCase(unittest.TestCase):
 
         if not self.__stopped:
             if timeout:
-                def timeout_func():
+                def timeout_func() -> None:
                     try:
                         raise self.failureException(
                             'Async operation timed out after %s seconds' %
@@ -310,7 +326,7 @@ class AsyncHTTPTestCase(AsyncTestCase):
     to do other asynchronous operations in tests, you'll probably need to use
     ``stop()`` and ``wait()`` yourself.
     """
-    def setUp(self):
+    def setUp(self) -> None:
         super(AsyncHTTPTestCase, self).setUp()
         sock, port = bind_unused_port()
         self.__port = port
@@ -320,19 +336,19 @@ class AsyncHTTPTestCase(AsyncTestCase):
         self.http_server = self.get_http_server()
         self.http_server.add_sockets([sock])
 
-    def get_http_client(self):
+    def get_http_client(self) -> AsyncHTTPClient:
         return AsyncHTTPClient()
 
-    def get_http_server(self):
+    def get_http_server(self) -> HTTPServer:
         return HTTPServer(self._app, **self.get_httpserver_options())
 
-    def get_app(self):
+    def get_app(self) -> Application:
         """Should be overridden by subclasses to return a
         `tornado.web.Application` or other `.HTTPServer` callback.
         """
         raise NotImplementedError()
 
-    def fetch(self, path, raise_error=False, **kwargs):
+    def fetch(self, path: str, raise_error: bool=False, **kwargs: Any) -> HTTPResponse:
         """Convenience method to synchronously fetch a URL.
 
         The given path will be appended to the local server's host and
@@ -374,28 +390,28 @@ class AsyncHTTPTestCase(AsyncTestCase):
             lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
             timeout=get_async_test_timeout())
 
-    def get_httpserver_options(self):
+    def get_httpserver_options(self) -> Dict[str, Any]:
         """May be overridden by subclasses to return additional
         keyword arguments for the server.
         """
         return {}
 
-    def get_http_port(self):
+    def get_http_port(self) -> int:
         """Returns the port used by the server.
 
         A new port is chosen for each test.
         """
         return self.__port
 
-    def get_protocol(self):
+    def get_protocol(self) -> str:
         return 'http'
 
-    def get_url(self, path):
+    def get_url(self, path: str) -> str:
         """Returns an absolute url for the given path on the test server."""
         return '%s://127.0.0.1:%s%s' % (self.get_protocol(),
                                         self.get_http_port(), path)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.http_server.stop()
         self.io_loop.run_sync(self.http_server.close_all_connections,
                               timeout=get_async_test_timeout())
@@ -408,14 +424,14 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
 
     Interface is generally the same as `AsyncHTTPTestCase`.
     """
-    def get_http_client(self):
+    def get_http_client(self) -> AsyncHTTPClient:
         return AsyncHTTPClient(force_instance=True,
                                defaults=dict(validate_cert=False))
 
-    def get_httpserver_options(self):
+    def get_httpserver_options(self) -> Dict[str, Any]:
         return dict(ssl_options=self.get_ssl_options())
 
-    def get_ssl_options(self):
+    def get_ssl_options(self) -> Dict[str, Any]:
         """May be overridden by subclasses to select SSL options.
 
         By default includes a self-signed testing certificate.
@@ -428,11 +444,25 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
             certfile=os.path.join(module_dir, 'test', 'test.crt'),
             keyfile=os.path.join(module_dir, 'test', 'test.key'))
 
-    def get_protocol(self):
+    def get_protocol(self) -> str:
         return 'https'
 
 
-def gen_test(func=None, timeout=None):
+@typing.overload
+def gen_test(*, timeout: float=None) -> Callable[[Callable[..., Union[Generator, Coroutine]]],
+                                                 Callable[..., None]]:
+    pass
+
+
+@typing.overload  # noqa: F811
+def gen_test(func: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
+    pass
+
+
+def gen_test(  # noqa: F811
+        func: Callable[..., Union[Generator, Coroutine]]=None, timeout: float=None,
+) -> Union[Callable[..., None],
+           Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]]:
     """Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
 
     ``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
@@ -471,7 +501,7 @@ def gen_test(func=None, timeout=None):
     if timeout is None:
         timeout = get_async_test_timeout()
 
-    def wrap(f):
+    def wrap(f: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
         # Stack up several decorators to allow us to access the generator
         # object itself.  In the innermost wrapper, we capture the generator
         # and save it in an attribute of self.  Next, we run the wrapped
@@ -482,6 +512,8 @@ def gen_test(func=None, timeout=None):
         # extensibility in the gen decorators or cancellation support.
         @functools.wraps(f)
         def pre_coroutine(self, *args, **kwargs):
+            # type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine]
+            # Type comments used to avoid pypy3 bug.
             result = f(self, *args, **kwargs)
             if isinstance(result, Generator) or inspect.iscoroutine(result):
                 self._test_generator = result
@@ -496,6 +528,7 @@ def gen_test(func=None, timeout=None):
 
         @functools.wraps(coro)
         def post_coroutine(self, *args, **kwargs):
+            # type: (AsyncTestCase, *Any, **Any) -> None
             try:
                 return self.io_loop.run_sync(
                     functools.partial(coro, self, *args, **kwargs),
@@ -507,8 +540,9 @@ def gen_test(func=None, timeout=None):
                 # point where the test is stopped. The only reason the generator
                 # would not be running would be if it were cancelled, which means
                 # a native coroutine, so we can rely on the cr_running attribute.
-                if getattr(self._test_generator, 'cr_running', True):
-                    self._test_generator.throw(e)
+                if (self._test_generator is not None and
+                        getattr(self._test_generator, 'cr_running', True)):
+                    self._test_generator.throw(type(e), e)
                     # In case the test contains an overly broad except
                     # clause, we may get back here.
                 # Coroutine was stopped or didn't raise a useful stack trace,
@@ -549,7 +583,8 @@ class ExpectLog(logging.Filter):
     .. versionchanged:: 4.3
        Added the ``logged_stack`` attribute.
     """
-    def __init__(self, logger, regex, required=True):
+    def __init__(self, logger: Union[logging.Logger, basestring_type], regex: str,
+                 required: bool=True) -> None:
         """Constructs an ExpectLog context manager.
 
         :param logger: Logger object (or name of logger) to watch.  Pass
@@ -567,7 +602,7 @@ class ExpectLog(logging.Filter):
         self.matched = False
         self.logged_stack = False
 
-    def filter(self, record):
+    def filter(self, record: logging.LogRecord) -> bool:
         if record.exc_info:
             self.logged_stack = True
         message = record.getMessage()
@@ -576,17 +611,18 @@ class ExpectLog(logging.Filter):
             return False
         return True
 
-    def __enter__(self):
+    def __enter__(self) -> logging.Filter:
         self.logger.addFilter(self)
         return self
 
-    def __exit__(self, typ, value, tb):
+    def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException],
+                 tb: Optional[TracebackType]) -> None:
         self.logger.removeFilter(self)
         if not typ and self.required and not self.matched:
             raise Exception("did not get expected log message")
 
 
-def main(**kwargs):
+def main(**kwargs: Any) -> None:
     """A simple test runner.
 
     This test runner is essentially equivalent to `unittest.main` from
@@ -667,7 +703,7 @@ def main(**kwargs):
     # test discovery, which is incompatible with auto2to3), so don't
     # set module if we're not asking for a specific test.
     if len(argv) > 1:
-        unittest.main(module=None, argv=argv, **kwargs)
+        unittest.main(module=None, argv=argv, **kwargs)  # type: ignore
     else:
         unittest.main(defaultTest="all", argv=argv, **kwargs)