From: Ben Darnell Date: Mon, 17 Sep 2018 12:54:20 +0000 (-0400) Subject: locks,queues: Add type annotations X-Git-Tag: v6.0.0b1~28^2~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8b18826c29c592abf2eb40b2ecb68b488297b0d0;p=thirdparty%2Ftornado.git locks,queues: Add type annotations --- diff --git a/setup.cfg b/setup.cfg index 91eb964ef..527646a63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ python_version = 3.5 [mypy-tornado.*,tornado.platform.*] disallow_untyped_defs = True -[mypy-tornado.auth,tornado.curl_httpclient,tornado.httpclient,tornado.locks,tornado.queues,tornado.routing,tornado.simple_httpclient,tornado.template,tornado.web,tornado.websocket,tornado.wsgi] +[mypy-tornado.auth,tornado.curl_httpclient,tornado.httpclient,tornado.routing,tornado.simple_httpclient,tornado.template,tornado.web,tornado.websocket,tornado.wsgi] disallow_untyped_defs = False # It's generally too tedious to require type annotations in tests, but diff --git a/tornado/locks.py b/tornado/locks.py index 711c6b320..e1b518aad 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -14,10 +14,17 @@ import collections from concurrent.futures import CancelledError +import datetime +import types from tornado import gen, ioloop from tornado.concurrent import Future, future_set_result_unless_cancelled +from typing import Union, Optional, Type, Any, Generator +import typing +if typing.TYPE_CHECKING: + from typing import Deque, Set # noqa: F401 + __all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock'] @@ -30,11 +37,11 @@ class _TimeoutGarbageCollector(object): yield condition.wait(short_timeout) print('looping....') """ - def __init__(self): - self._waiters = collections.deque() # Futures. + def __init__(self) -> None: + self._waiters = collections.deque() # type: Deque[Future] self._timeouts = 0 - def _garbage_collect(self): + def _garbage_collect(self) -> None: # Occasionally clear timed-out waiters. self._timeouts += 1 if self._timeouts > 100: @@ -103,26 +110,26 @@ class Condition(_TimeoutGarbageCollector): next iteration of the `.IOLoop`. """ - def __init__(self): + def __init__(self) -> None: super(Condition, self).__init__() self.io_loop = ioloop.IOLoop.current() - def __repr__(self): + def __repr__(self) -> str: result = '<%s' % (self.__class__.__name__, ) if self._waiters: result += ' waiters[%s]' % len(self._waiters) return result + '>' - def wait(self, timeout=None): + def wait(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[bool]': """Wait for `.notify`. Returns a `.Future` that resolves ``True`` if the condition is notified, or ``False`` after a timeout. """ - waiter = Future() + waiter = Future() # type: Future[bool] self._waiters.append(waiter) if timeout: - def on_timeout(): + def on_timeout() -> None: if not waiter.done(): future_set_result_unless_cancelled(waiter, False) self._garbage_collect() @@ -132,7 +139,7 @@ class Condition(_TimeoutGarbageCollector): lambda _: io_loop.remove_timeout(timeout_handle)) return waiter - def notify(self, n=1): + def notify(self, n: int=1) -> None: """Wake ``n`` waiters.""" waiters = [] # Waiters we plan to run right now. while n and self._waiters: @@ -144,7 +151,7 @@ class Condition(_TimeoutGarbageCollector): for waiter in waiters: future_set_result_unless_cancelled(waiter, True) - def notify_all(self): + def notify_all(self) -> None: """Wake all waiters.""" self.notify(len(self._waiters)) @@ -188,19 +195,19 @@ class Event(object): Not waiting this time Done """ - def __init__(self): + def __init__(self) -> None: self._value = False - self._waiters = set() + self._waiters = set() # type: Set[Future[None]] - def __repr__(self): + def __repr__(self) -> str: return '<%s %s>' % ( self.__class__.__name__, 'set' if self.is_set() else 'clear') - def is_set(self): + def is_set(self) -> bool: """Return ``True`` if the internal flag is true.""" return self._value - def set(self): + def set(self) -> None: """Set the internal flag to ``True``. All waiters are awakened. Calling `.wait` once the flag is set will not block. @@ -212,20 +219,20 @@ class Event(object): if not fut.done(): fut.set_result(None) - def clear(self): + def clear(self) -> None: """Reset the internal flag to ``False``. Calls to `.wait` will block until `.set` is called. """ self._value = False - def wait(self, timeout=None): + def wait(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]': """Block until the internal flag is true. Returns a Future, which raises `tornado.util.TimeoutError` after a timeout. """ - fut = Future() + fut = Future() # type: Future[None] if self._value: fut.set_result(None) return fut @@ -250,13 +257,15 @@ class _ReleasingContextManager(object): # Now semaphore.release() has been called. """ - def __init__(self, obj): + def __init__(self, obj: Any) -> None: self._obj = obj - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType]) -> None: self._obj.release() @@ -355,14 +364,14 @@ class Semaphore(_TimeoutGarbageCollector): Added ``async with`` support in Python 3.5. """ - def __init__(self, value=1): + def __init__(self, value: int=1) -> None: super(Semaphore, self).__init__() if value < 0: raise ValueError('semaphore initial value must be >= 0') self._value = value - def __repr__(self): + def __repr__(self) -> str: res = super(Semaphore, self).__repr__() extra = 'locked' if self._value == 0 else 'unlocked,value:{0}'.format( self._value) @@ -370,7 +379,7 @@ class Semaphore(_TimeoutGarbageCollector): extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) return '<{0} [{1}]>'.format(res[1:-1], extra) - def release(self): + def release(self) -> None: """Increment the counter and wake one waiter.""" self._value += 1 while self._waiters: @@ -387,20 +396,22 @@ class Semaphore(_TimeoutGarbageCollector): waiter.set_result(_ReleasingContextManager(self)) break - def acquire(self, timeout=None): + def acquire( + self, timeout: Union[float, datetime.timedelta]=None, + ) -> 'Future[_ReleasingContextManager]': """Decrement the counter. Returns a Future. Block if the counter is zero and wait for a `.release`. The Future raises `.TimeoutError` after the deadline. """ - waiter = Future() + waiter = Future() # type: Future[_ReleasingContextManager] if self._value > 0: self._value -= 1 waiter.set_result(_ReleasingContextManager(self)) else: self._waiters.append(waiter) if timeout: - def on_timeout(): + def on_timeout() -> None: if not waiter.done(): waiter.set_exception(gen.TimeoutError()) self._garbage_collect() @@ -410,20 +421,23 @@ class Semaphore(_TimeoutGarbageCollector): lambda _: io_loop.remove_timeout(timeout_handle)) return waiter - def __enter__(self): + def __enter__(self) -> None: raise RuntimeError( "Use Semaphore like 'with (yield semaphore.acquire())', not like" " 'with semaphore'") - def __exit__(self, typ, value, traceback): + def __exit__(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[types.TracebackType]) -> None: self.__enter__() @gen.coroutine - def __aenter__(self): + def __aenter__(self) -> Generator[Any, Any, None]: yield self.acquire() - @gen.coroutine - def __aexit__(self, typ, value, tb): + async def __aexit__(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[types.TracebackType]) -> None: self.release() @@ -435,11 +449,11 @@ class BoundedSemaphore(Semaphore): resources with limited capacity, so a semaphore released too many times is a sign of a bug. """ - def __init__(self, value=1): + def __init__(self, value: int=1) -> None: super(BoundedSemaphore, self).__init__(value=value) self._initial_value = value - def release(self): + def release(self) -> None: """Increment the counter and wake one waiter.""" if self._value >= self._initial_value: raise ValueError("Semaphore released too many times") @@ -482,15 +496,17 @@ class Lock(object): Added ``async with`` support in Python 3.5. """ - def __init__(self): + def __init__(self) -> None: self._block = BoundedSemaphore(value=1) - def __repr__(self): + def __repr__(self) -> str: return "<%s _block=%s>" % ( self.__class__.__name__, self._block) - def acquire(self, timeout=None): + def acquire( + self, timeout: Union[float, datetime.timedelta]=None, + ) -> 'Future[_ReleasingContextManager]': """Attempt to lock. Returns a Future. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -498,7 +514,7 @@ class Lock(object): """ return self._block.acquire(timeout) - def release(self): + def release(self) -> None: """Unlock. The first coroutine in line waiting for `acquire` gets the lock. @@ -510,17 +526,20 @@ class Lock(object): except ValueError: raise RuntimeError('release unlocked lock') - def __enter__(self): + def __enter__(self) -> None: raise RuntimeError( "Use Lock like 'with (yield lock)', not like 'with lock'") - def __exit__(self, typ, value, tb): + def __exit__(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[types.TracebackType]) -> None: self.__enter__() @gen.coroutine - def __aenter__(self): + def __aenter__(self) -> Generator[Any, Any, None]: yield self.acquire() - @gen.coroutine - def __aexit__(self, typ, value, tb): + async def __aexit__(self, typ: Optional[Type[BaseException]], + value: Optional[BaseException], + tb: Optional[types.TracebackType]) -> None: self.release() diff --git a/tornado/queues.py b/tornado/queues.py index bc3338447..49fa9e602 100644 --- a/tornado/queues.py +++ b/tornado/queues.py @@ -26,12 +26,20 @@ to those provided in the standard library's `asyncio package """ import collections +import datetime import heapq from tornado import gen, ioloop from tornado.concurrent import Future, future_set_result_unless_cancelled from tornado.locks import Event +from typing import Union, TypeVar, Generic, Awaitable +import typing +if typing.TYPE_CHECKING: + from typing import Deque, Tuple, List, Any # noqa: F401 + +_T = TypeVar('_T') + __all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] @@ -45,9 +53,9 @@ class QueueFull(Exception): pass -def _set_timeout(future, timeout): +def _set_timeout(future: Future, timeout: Union[None, float, datetime.timedelta]) -> None: if timeout: - def on_timeout(): + def on_timeout() -> None: if not future.done(): future.set_exception(gen.TimeoutError()) io_loop = ioloop.IOLoop.current() @@ -56,15 +64,15 @@ def _set_timeout(future, timeout): lambda _: io_loop.remove_timeout(timeout_handle)) -class _QueueIterator(object): - def __init__(self, q): +class _QueueIterator(Generic[_T]): + def __init__(self, q: 'Queue[_T]') -> None: self.q = q - def __anext__(self): + def __anext__(self) -> Awaitable[_T]: return self.q.get() -class Queue(object): +class Queue(Generic[_T]): """Coordinate producer and consumer coroutines. If maxsize is 0 (the default) the queue size is unbounded. @@ -131,7 +139,11 @@ class Queue(object): Added ``async for`` support in Python 3.5. """ - def __init__(self, maxsize=0): + # Exact type depends on subclass. Could be another generic + # parameter and use protocols to be more precise here. + _queue = None # type: Any + + def __init__(self, maxsize: int=0) -> None: if maxsize is None: raise TypeError("maxsize can't be None") @@ -140,31 +152,31 @@ class Queue(object): self._maxsize = maxsize self._init() - self._getters = collections.deque([]) # Futures. - self._putters = collections.deque([]) # Pairs of (item, Future). + self._getters = collections.deque([]) # type: Deque[Future[_T]] + self._putters = collections.deque([]) # type: Deque[Tuple[_T, Future[None]]] self._unfinished_tasks = 0 self._finished = Event() self._finished.set() @property - def maxsize(self): + def maxsize(self) -> int: """Number of items allowed in the queue.""" return self._maxsize - def qsize(self): + def qsize(self) -> int: """Number of items in the queue.""" return len(self._queue) - def empty(self): + def empty(self) -> bool: return not self._queue - def full(self): + def full(self) -> bool: if self.maxsize == 0: return False else: return self.qsize() >= self.maxsize - def put(self, item, timeout=None): + def put(self, item: _T, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]': """Put an item into the queue, perhaps waiting until there is room. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -175,7 +187,7 @@ class Queue(object): `datetime.timedelta` object for a deadline relative to the current time. """ - future = Future() + future = Future() # type: Future[None] try: self.put_nowait(item) except QueueFull: @@ -185,7 +197,7 @@ class Queue(object): future.set_result(None) return future - def put_nowait(self, item): + def put_nowait(self, item: _T) -> None: """Put an item into the queue without blocking. If no free slot is immediately available, raise `QueueFull`. @@ -201,7 +213,7 @@ class Queue(object): else: self.__put_internal(item) - def get(self, timeout=None): + def get(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[_T]': """Remove and return an item from the queue. Returns a Future which resolves once an item is available, or raises @@ -212,7 +224,7 @@ class Queue(object): `datetime.timedelta` object for a deadline relative to the current time. """ - future = Future() + future = Future() # type: Future[_T] try: future.set_result(self.get_nowait()) except QueueEmpty: @@ -220,7 +232,7 @@ class Queue(object): _set_timeout(future, timeout) return future - def get_nowait(self): + def get_nowait(self) -> _T: """Remove and return an item from the queue without blocking. Return an item if one is immediately available, else raise @@ -238,7 +250,7 @@ class Queue(object): else: raise QueueEmpty - def task_done(self): + def task_done(self) -> None: """Indicate that a formerly enqueued task is complete. Used by queue consumers. For each `.get` used to fetch a task, a @@ -256,7 +268,7 @@ class Queue(object): if self._unfinished_tasks == 0: self._finished.set() - def join(self, timeout=None): + def join(self, timeout: Union[float, datetime.timedelta]=None) -> 'Future[None]': """Block until all items in the queue are processed. Returns a Future, which raises `tornado.util.TimeoutError` after a @@ -264,26 +276,26 @@ class Queue(object): """ return self._finished.wait(timeout) - def __aiter__(self): + def __aiter__(self) -> _QueueIterator[_T]: return _QueueIterator(self) # These three are overridable in subclasses. - def _init(self): + def _init(self) -> None: self._queue = collections.deque() - def _get(self): + def _get(self) -> _T: return self._queue.popleft() - def _put(self, item): + def _put(self, item: _T) -> None: self._queue.append(item) # End of the overridable methods. - def __put_internal(self, item): + def __put_internal(self, item: _T) -> None: self._unfinished_tasks += 1 self._finished.clear() self._put(item) - def _consume_expired(self): + def _consume_expired(self) -> None: # Remove timed-out waiters. while self._putters and self._putters[0][1].done(): self._putters.popleft() @@ -291,14 +303,14 @@ class Queue(object): while self._getters and self._getters[0].done(): self._getters.popleft() - def __repr__(self): + def __repr__(self) -> str: return '<%s at %s %s>' % ( type(self).__name__, hex(id(self)), self._format()) - def __str__(self): + def __str__(self) -> str: return '<%s %s>' % (type(self).__name__, self._format()) - def _format(self): + def _format(self) -> str: result = 'maxsize=%r' % (self.maxsize, ) if getattr(self, '_queue', None): result += ' queue=%r' % self._queue @@ -335,13 +347,13 @@ class PriorityQueue(Queue): (1, 'medium-priority item') (10, 'low-priority item') """ - def _init(self): + def _init(self) -> None: self._queue = [] - def _put(self, item): + def _put(self, item: _T) -> None: heapq.heappush(self._queue, item) - def _get(self): + def _get(self) -> _T: return heapq.heappop(self._queue) @@ -367,11 +379,11 @@ class LifoQueue(Queue): 2 3 """ - def _init(self): + def _init(self) -> None: self._queue = [] - def _put(self, item): + def _put(self, item: _T) -> None: self._queue.append(item) - def _get(self): + def _get(self) -> _T: return self._queue.pop() diff --git a/tornado/test/queues_test.py b/tornado/test/queues_test.py index 954021327..e83f26668 100644 --- a/tornado/test/queues_test.py +++ b/tornado/test/queues_test.py @@ -21,7 +21,7 @@ from tornado.testing import gen_test, AsyncTestCase class QueueBasicTest(AsyncTestCase): def test_repr_and_str(self): - q = queues.Queue(maxsize=1) + q = queues.Queue(maxsize=1) # type: queues.Queue[None] self.assertIn(hex(id(q)), repr(q)) self.assertNotIn(hex(id(q)), str(q)) q.get() @@ -44,7 +44,7 @@ class QueueBasicTest(AsyncTestCase): self.assertIn('tasks=2', q_str) def test_order(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] for i in [1, 3, 2]: q.put_nowait(i) @@ -56,7 +56,7 @@ class QueueBasicTest(AsyncTestCase): self.assertRaises(TypeError, queues.Queue, maxsize=None) self.assertRaises(ValueError, queues.Queue, maxsize=-1) - q = queues.Queue(maxsize=2) + q = queues.Queue(maxsize=2) # type: queues.Queue[int] self.assertTrue(q.empty()) self.assertFalse(q.full()) self.assertEqual(2, q.maxsize) @@ -75,22 +75,22 @@ class QueueBasicTest(AsyncTestCase): class QueueGetTest(AsyncTestCase): @gen_test def test_blocking_get(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] q.put_nowait(0) self.assertEqual(0, (yield q.get())) def test_nonblocking_get(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] q.put_nowait(0) self.assertEqual(0, q.get_nowait()) def test_nonblocking_get_exception(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] self.assertRaises(queues.QueueEmpty, q.get_nowait) @gen_test def test_get_with_putters(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] q.put_nowait(0) put = q.put(1) self.assertEqual(0, (yield q.get())) @@ -98,7 +98,7 @@ class QueueGetTest(AsyncTestCase): @gen_test def test_blocking_get_wait(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] q.put(0) self.io_loop.call_later(0.01, q.put, 1) self.io_loop.call_later(0.02, q.put, 2) @@ -107,7 +107,7 @@ class QueueGetTest(AsyncTestCase): @gen_test def test_get_timeout(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] get_timeout = q.get(timeout=timedelta(seconds=0.01)) get = q.get() with self.assertRaises(TimeoutError): @@ -118,7 +118,7 @@ class QueueGetTest(AsyncTestCase): @gen_test def test_get_timeout_preempted(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] get = q.get(timeout=timedelta(seconds=0.01)) q.put(0) yield gen.sleep(0.02) @@ -126,7 +126,7 @@ class QueueGetTest(AsyncTestCase): @gen_test def test_get_clears_timed_out_putters(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] # First putter succeeds, remainder block. putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)] put = q.put(10) @@ -142,7 +142,7 @@ class QueueGetTest(AsyncTestCase): @gen_test def test_get_clears_timed_out_getters(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)] get = q.get() self.assertEqual(11, len(q._getters)) @@ -156,7 +156,7 @@ class QueueGetTest(AsyncTestCase): @gen_test def test_async_for(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] for i in range(5): q.put(i) @@ -173,18 +173,18 @@ class QueueGetTest(AsyncTestCase): class QueuePutTest(AsyncTestCase): @gen_test def test_blocking_put(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] q.put(0) self.assertEqual(0, q.get_nowait()) def test_nonblocking_put_exception(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] q.put(0) self.assertRaises(queues.QueueFull, q.put_nowait, 1) @gen_test def test_put_with_getters(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] get0 = q.get() get1 = q.get() yield q.put(0) @@ -194,7 +194,7 @@ class QueuePutTest(AsyncTestCase): @gen_test def test_nonblocking_put_with_getters(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] get0 = q.get() get1 = q.get() q.put_nowait(0) @@ -207,7 +207,7 @@ class QueuePutTest(AsyncTestCase): @gen_test def test_blocking_put_wait(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] q.put_nowait(0) self.io_loop.call_later(0.01, q.get) self.io_loop.call_later(0.02, q.get) @@ -217,7 +217,7 @@ class QueuePutTest(AsyncTestCase): @gen_test def test_put_timeout(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] q.put_nowait(0) # Now it's full. put_timeout = q.put(1, timeout=timedelta(seconds=0.01)) put = q.put(2) @@ -233,7 +233,7 @@ class QueuePutTest(AsyncTestCase): @gen_test def test_put_timeout_preempted(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] q.put_nowait(0) put = q.put(1, timeout=timedelta(seconds=0.01)) q.get() @@ -242,7 +242,7 @@ class QueuePutTest(AsyncTestCase): @gen_test def test_put_clears_timed_out_putters(self): - q = queues.Queue(1) + q = queues.Queue(1) # type: queues.Queue[int] # First putter succeeds, remainder block. putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)] put = q.put(10) @@ -257,7 +257,7 @@ class QueuePutTest(AsyncTestCase): @gen_test def test_put_clears_timed_out_getters(self): - q = queues.Queue() + q = queues.Queue() # type: queues.Queue[int] getters = [q.get(timedelta(seconds=0.01)) for _ in range(10)] get = q.get() q.get() @@ -393,7 +393,7 @@ class LifoQueueJoinTest(QueueJoinTest): class ProducerConsumerTest(AsyncTestCase): @gen_test def test_producer_consumer(self): - q = queues.Queue(maxsize=3) + q = queues.Queue(maxsize=3) # type: queues.Queue[int] history = [] # We don't yield between get() and task_done(), so get() must wait for diff --git a/tornado/test/tcpclient_test.py b/tornado/test/tcpclient_test.py index 63cf23cf5..3b810d5fe 100644 --- a/tornado/test/tcpclient_test.py +++ b/tornado/test/tcpclient_test.py @@ -40,7 +40,7 @@ class TestTCPServer(TCPServer): def __init__(self, family): super(TestTCPServer, self).__init__() self.streams = [] # type: List[IOStream] - self.queue = Queue() + self.queue = Queue() # type: Queue[IOStream] sockets = bind_sockets(0, 'localhost', family) self.add_sockets(sockets) self.port = sockets[0].getsockname()[1]