import unittest
from tornado import gen
-from tornado.httpclient import AsyncHTTPClient
+from tornado.httpclient import AsyncHTTPClient, HTTPResponse
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
from tornado.process import Subprocess
from tornado.log import app_log
from tornado.util import raise_exc_info, basestring_type
+from tornado.web import Application
+
+import typing
+from typing import Tuple, Any, Callable, Type, Dict, Union, Coroutine, Optional
+from types import TracebackType
+
+if typing.TYPE_CHECKING:
+ _ExcInfoTuple = Tuple[Optional[Type[BaseException]], Optional[BaseException],
+ Optional[TracebackType]]
_NON_OWNED_IOLOOPS = AsyncIOMainLoop
-def bind_unused_port(reuse_port=False):
+def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]:
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
return sock, port
-def get_async_test_timeout():
+def get_async_test_timeout() -> float:
"""Get the global timeout setting for async tests.
Returns a float, the timeout in seconds.
.. versionadded:: 3.1
"""
- try:
- return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
- except (ValueError, TypeError):
- return 5
+ env = os.environ.get('ASYNC_TEST_TIMEOUT')
+ if env is not None:
+ try:
+ return float(env)
+ except ValueError:
+ pass
+ return 5
class _TestMethodWrapper(object):
necessarily errors, but we alert anyway since there is no good
reason to return a value from a test).
"""
- def __init__(self, orig_method):
+ def __init__(self, orig_method: Callable) -> None:
self.orig_method = orig_method
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args: Any, **kwargs: Any) -> None:
result = self.orig_method(*args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
raise TypeError("Generator and coroutine test methods should be"
raise ValueError("Return value from test method ignored: %r" %
result)
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
"""Proxy all unknown attributes to the original method.
This is important for some of the decorators in the `unittest`
# Test contents of response
self.assertIn("FriendFeed", response.body)
"""
- def __init__(self, methodName='runTest'):
+ def __init__(self, methodName: str='runTest') -> None:
super(AsyncTestCase, self).__init__(methodName)
self.__stopped = False
self.__running = False
- self.__failure = None
- self.__stop_args = None
+ self.__failure = None # type: Optional[_ExcInfoTuple]
+ self.__stop_args = None # type: Any
self.__timeout = None
# It's easy to forget the @gen_test decorator, but if you do
# make sure it's not an undecorated generator.
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
- def setUp(self):
+ # Not used in this class itself, but used by @gen_test
+ self._test_generator = None # type: Optional[Union[Generator, Coroutine]]
+
+ def setUp(self) -> None:
super(AsyncTestCase, self).setUp()
self.io_loop = self.get_new_ioloop()
self.io_loop.make_current()
- def tearDown(self):
+ def tearDown(self) -> None:
# Clean up Subprocess, so it can be used again with a new ioloop.
Subprocess.uninitialize()
self.io_loop.clear_current()
# unittest machinery understands.
self.__rethrow()
- def get_new_ioloop(self):
+ def get_new_ioloop(self) -> IOLoop:
"""Returns the `.IOLoop` to use for this test.
By default, a new `.IOLoop` is created for each test.
"""
return IOLoop()
- def _handle_exception(self, typ, value, tb):
+ def _handle_exception(self, typ: Type[Exception], value: Exception, tb: TracebackType) -> bool:
if self.__failure is None:
self.__failure = (typ, value, tb)
else:
self.stop()
return True
- def __rethrow(self):
+ def __rethrow(self) -> None:
if self.__failure is not None:
failure = self.__failure
self.__failure = None
raise_exc_info(failure)
- def run(self, result=None):
- super(AsyncTestCase, self).run(result)
+ def run(self, result: unittest.TestResult=None) -> unittest.TestCase:
+ ret = super(AsyncTestCase, self).run(result)
# As a last resort, if an exception escaped super.run() and wasn't
# re-raised in tearDown, raise it here. This will cause the
# unittest run to fail messily, but that's better than silently
# ignoring an error.
self.__rethrow()
+ return ret
- def stop(self, _arg=None, **kwargs):
+ def stop(self, _arg: Any=None, **kwargs: Any) -> None:
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
to return.
self.__running = False
self.__stopped = True
- def wait(self, condition=None, timeout=None):
+ def wait(self, condition: Callable[..., bool]=None, timeout: float=None) -> None:
"""Runs the `.IOLoop` until stop is called or timeout has passed.
In the event of a timeout, an exception will be thrown. The
if not self.__stopped:
if timeout:
- def timeout_func():
+ def timeout_func() -> None:
try:
raise self.failureException(
'Async operation timed out after %s seconds' %
to do other asynchronous operations in tests, you'll probably need to use
``stop()`` and ``wait()`` yourself.
"""
- def setUp(self):
+ def setUp(self) -> None:
super(AsyncHTTPTestCase, self).setUp()
sock, port = bind_unused_port()
self.__port = port
self.http_server = self.get_http_server()
self.http_server.add_sockets([sock])
- def get_http_client(self):
+ def get_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient()
- def get_http_server(self):
+ def get_http_server(self) -> HTTPServer:
return HTTPServer(self._app, **self.get_httpserver_options())
- def get_app(self):
+ def get_app(self) -> Application:
"""Should be overridden by subclasses to return a
`tornado.web.Application` or other `.HTTPServer` callback.
"""
raise NotImplementedError()
- def fetch(self, path, raise_error=False, **kwargs):
+ def fetch(self, path: str, raise_error: bool=False, **kwargs: Any) -> HTTPResponse:
"""Convenience method to synchronously fetch a URL.
The given path will be appended to the local server's host and
lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
timeout=get_async_test_timeout())
- def get_httpserver_options(self):
+ def get_httpserver_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to return additional
keyword arguments for the server.
"""
return {}
- def get_http_port(self):
+ def get_http_port(self) -> int:
"""Returns the port used by the server.
A new port is chosen for each test.
"""
return self.__port
- def get_protocol(self):
+ def get_protocol(self) -> str:
return 'http'
- def get_url(self, path):
+ def get_url(self, path: str) -> str:
"""Returns an absolute url for the given path on the test server."""
return '%s://127.0.0.1:%s%s' % (self.get_protocol(),
self.get_http_port(), path)
- def tearDown(self):
+ def tearDown(self) -> None:
self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections,
timeout=get_async_test_timeout())
Interface is generally the same as `AsyncHTTPTestCase`.
"""
- def get_http_client(self):
+ def get_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient(force_instance=True,
defaults=dict(validate_cert=False))
- def get_httpserver_options(self):
+ def get_httpserver_options(self) -> Dict[str, Any]:
return dict(ssl_options=self.get_ssl_options())
- def get_ssl_options(self):
+ def get_ssl_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to select SSL options.
By default includes a self-signed testing certificate.
certfile=os.path.join(module_dir, 'test', 'test.crt'),
keyfile=os.path.join(module_dir, 'test', 'test.key'))
- def get_protocol(self):
+ def get_protocol(self) -> str:
return 'https'
-def gen_test(func=None, timeout=None):
+@typing.overload
+def gen_test(*, timeout: float=None) -> Callable[[Callable[..., Union[Generator, Coroutine]]],
+ Callable[..., None]]:
+ pass
+
+
+@typing.overload # noqa: F811
+def gen_test(func: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
+ pass
+
+
+def gen_test( # noqa: F811
+ func: Callable[..., Union[Generator, Coroutine]]=None, timeout: float=None,
+) -> Union[Callable[..., None],
+ Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]]:
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
if timeout is None:
timeout = get_async_test_timeout()
- def wrap(f):
+ def wrap(f: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
# Stack up several decorators to allow us to access the generator
# object itself. In the innermost wrapper, we capture the generator
# and save it in an attribute of self. Next, we run the wrapped
# extensibility in the gen decorators or cancellation support.
@functools.wraps(f)
def pre_coroutine(self, *args, **kwargs):
+ # type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine]
+ # Type comments used to avoid pypy3 bug.
result = f(self, *args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
self._test_generator = result
@functools.wraps(coro)
def post_coroutine(self, *args, **kwargs):
+ # type: (AsyncTestCase, *Any, **Any) -> None
try:
return self.io_loop.run_sync(
functools.partial(coro, self, *args, **kwargs),
# point where the test is stopped. The only reason the generator
# would not be running would be if it were cancelled, which means
# a native coroutine, so we can rely on the cr_running attribute.
- if getattr(self._test_generator, 'cr_running', True):
- self._test_generator.throw(e)
+ if (self._test_generator is not None and
+ getattr(self._test_generator, 'cr_running', True)):
+ self._test_generator.throw(type(e), e)
# In case the test contains an overly broad except
# clause, we may get back here.
# Coroutine was stopped or didn't raise a useful stack trace,
.. versionchanged:: 4.3
Added the ``logged_stack`` attribute.
"""
- def __init__(self, logger, regex, required=True):
+ def __init__(self, logger: Union[logging.Logger, basestring_type], regex: str,
+ required: bool=True) -> None:
"""Constructs an ExpectLog context manager.
:param logger: Logger object (or name of logger) to watch. Pass
self.matched = False
self.logged_stack = False
- def filter(self, record):
+ def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info:
self.logged_stack = True
message = record.getMessage()
return False
return True
- def __enter__(self):
+ def __enter__(self) -> logging.Filter:
self.logger.addFilter(self)
return self
- def __exit__(self, typ, value, tb):
+ def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException],
+ tb: Optional[TracebackType]) -> None:
self.logger.removeFilter(self)
if not typ and self.required and not self.matched:
raise Exception("did not get expected log message")
-def main(**kwargs):
+def main(**kwargs: Any) -> None:
"""A simple test runner.
This test runner is essentially equivalent to `unittest.main` from
# test discovery, which is incompatible with auto2to3), so don't
# set module if we're not asking for a specific test.
if len(argv) > 1:
- unittest.main(module=None, argv=argv, **kwargs)
+ unittest.main(module=None, argv=argv, **kwargs) # type: ignore
else:
unittest.main(defaultTest="all", argv=argv, **kwargs)