From: Ben Darnell Date: Sun, 2 Aug 2015 22:08:11 +0000 (-0400) Subject: Implement the async context manager protocol in tornado.locks. X-Git-Tag: v4.3.0b1~63^2~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=05d53e8c4fea0ef687f175b95e846e375f2ce245;p=thirdparty%2Ftornado.git Implement the async context manager protocol in tornado.locks. --- diff --git a/tornado/locks.py b/tornado/locks.py index 27e14953d..b605e8027 100644 --- a/tornado/locks.py +++ b/tornado/locks.py @@ -327,6 +327,17 @@ class Semaphore(_TimeoutGarbageCollector): # Now the semaphore has been released. print("Worker %d is done" % worker_id) + + In Python 3.5, the semaphore itself can be used as an async context + manager:: + + async def worker(worker_id): + async with sem: + print("Worker %d is working" % worker_id) + await use_some_resource() + + # Now the semaphore has been released. + print("Worker %d is done" % worker_id) """ def __init__(self, value=1): super(Semaphore, self).__init__() @@ -389,6 +400,14 @@ class Semaphore(_TimeoutGarbageCollector): __exit__ = __enter__ + @gen.coroutine + def __aenter__(self): + yield self.acquire() + + @gen.coroutine + def __aexit__(self, typ, value, tb): + self.release() + class BoundedSemaphore(Semaphore): """A semaphore that prevents release() being called too many times. @@ -418,7 +437,7 @@ class Lock(object): Releasing an unlocked lock raises `RuntimeError`. - `acquire` supports the context manager protocol: + `acquire` supports the context manager protocol in all Python versions: >>> from tornado import gen, locks >>> lock = locks.Lock() @@ -430,6 +449,16 @@ class Lock(object): ... pass ... ... # Now the lock is released. + + In Python 3.5, `Lock` also supports the async context manager protocol. + Note that in this case there is no `acquire`: + + >>> async def f(): # doctest: +SKIP + ... async with lock: + ... # Do something holding the lock. + ... pass + ... + ... # Now the lock is released. """ def __init__(self): self._block = BoundedSemaphore(value=1) @@ -464,3 +493,11 @@ class Lock(object): "Use Lock like 'with (yield lock)', not like 'with lock'") __exit__ = __enter__ + + @gen.coroutine + def __aenter__(self): + yield self.acquire() + + @gen.coroutine + def __aexit__(self, typ, value, tb): + self.release() diff --git a/tornado/test/locks_test.py b/tornado/test/locks_test.py index 90bdafaa6..b7b630970 100644 --- a/tornado/test/locks_test.py +++ b/tornado/test/locks_test.py @@ -11,12 +11,16 @@ # under the License. from datetime import timedelta +import sys +import textwrap from tornado import gen, locks from tornado.gen import TimeoutError from tornado.testing import gen_test, AsyncTestCase from tornado.test.util import unittest +skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available') + class ConditionTest(AsyncTestCase): def setUp(self): @@ -328,6 +332,24 @@ class SemaphoreContextManagerTest(AsyncTestCase): # Semaphore was released and can be acquired again. self.assertTrue(sem.acquire().done()) + @skipBefore35 + @gen_test + def test_context_manager_async_await(self): + # Repeat the above test using 'async with'. + sem = locks.Semaphore() + + global_namespace = dict(globals(), **locals()) + local_namespace = {} + exec(textwrap.dedent(""" + async def f(): + async with sem as yielded: + self.assertTrue(yielded is None) + """), global_namespace, local_namespace) + yield local_namespace['f']() + + # Semaphore was released and can be acquired again. + self.assertTrue(sem.acquire().done()) + @gen_test def test_context_manager_exception(self): sem = locks.Semaphore() @@ -443,6 +465,28 @@ class LockTests(AsyncTestCase): yield futures self.assertEqual(list(range(N)), history) + @skipBefore35 + @gen_test + def test_acquire_fifo_async_with(self): + # Repeat the above test using `async with lock:` + # instead of `with (yield lock.acquire()):`. + lock = locks.Lock() + self.assertTrue(lock.acquire().done()) + N = 5 + history = [] + + global_namespace = dict(globals(), **locals()) + local_namespace = {} + exec(textwrap.dedent(""" + async def f(idx): + async with lock: + history.append(idx) + """), global_namespace, local_namespace) + futures = [local_namespace['f'](i) for i in range(N)] + lock.release() + yield futures + self.assertEqual(list(range(N)), history) + @gen_test def test_acquire_timeout(self): lock = locks.Lock()