]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Implement the async context manager protocol in tornado.locks.
authorBen Darnell <ben@bendarnell.com>
Sun, 2 Aug 2015 22:08:11 +0000 (18:08 -0400)
committerBen Darnell <ben@bendarnell.com>
Sun, 2 Aug 2015 22:08:11 +0000 (18:08 -0400)
tornado/locks.py
tornado/test/locks_test.py

index 27e14953d34d6b4b728ffdfaf8863457c5106fee..b605e8027b4fbff87447af5d037ea6da912d3aef 100644 (file)
@@ -327,6 +327,17 @@ class Semaphore(_TimeoutGarbageCollector):
 
             # Now the semaphore has been released.
             print("Worker %d is done" % worker_id)
+
+    In Python 3.5, the semaphore itself can be used as an async context
+    manager::
+
+        async def worker(worker_id):
+            async with sem:
+                print("Worker %d is working" % worker_id)
+                await use_some_resource()
+
+            # Now the semaphore has been released.
+            print("Worker %d is done" % worker_id)
     """
     def __init__(self, value=1):
         super(Semaphore, self).__init__()
@@ -389,6 +400,14 @@ class Semaphore(_TimeoutGarbageCollector):
 
     __exit__ = __enter__
 
+    @gen.coroutine
+    def __aenter__(self):
+        yield self.acquire()
+
+    @gen.coroutine
+    def __aexit__(self, typ, value, tb):
+        self.release()
+
 
 class BoundedSemaphore(Semaphore):
     """A semaphore that prevents release() being called too many times.
@@ -418,7 +437,7 @@ class Lock(object):
 
     Releasing an unlocked lock raises `RuntimeError`.
 
-    `acquire` supports the context manager protocol:
+    `acquire` supports the context manager protocol in all Python versions:
 
     >>> from tornado import gen, locks
     >>> lock = locks.Lock()
@@ -430,6 +449,16 @@ class Lock(object):
     ...        pass
     ...
     ...    # Now the lock is released.
+
+    In Python 3.5, `Lock` also supports the async context manager protocol.
+    Note that in this case there is no `acquire`:
+
+    >>> async def f():  # doctest: +SKIP
+    ...    async with lock:
+    ...        # Do something holding the lock.
+    ...        pass
+    ...
+    ...    # Now the lock is released.
     """
     def __init__(self):
         self._block = BoundedSemaphore(value=1)
@@ -464,3 +493,11 @@ class Lock(object):
             "Use Lock like 'with (yield lock)', not like 'with lock'")
 
     __exit__ = __enter__
+
+    @gen.coroutine
+    def __aenter__(self):
+        yield self.acquire()
+
+    @gen.coroutine
+    def __aexit__(self, typ, value, tb):
+        self.release()
index 90bdafaa6020a64057b8299eace09785a13dfb05..b7b630970c91f6a8effbbc8bb197a6a5a21fc77f 100644 (file)
 # under the License.
 
 from datetime import timedelta
+import sys
+import textwrap
 
 from tornado import gen, locks
 from tornado.gen import TimeoutError
 from tornado.testing import gen_test, AsyncTestCase
 from tornado.test.util import unittest
 
+skipBefore35 = unittest.skipIf(sys.version_info < (3, 5), 'PEP 492 (async/await) not available')
+
 
 class ConditionTest(AsyncTestCase):
     def setUp(self):
@@ -328,6 +332,24 @@ class SemaphoreContextManagerTest(AsyncTestCase):
         # Semaphore was released and can be acquired again.
         self.assertTrue(sem.acquire().done())
 
+    @skipBefore35
+    @gen_test
+    def test_context_manager_async_await(self):
+        # Repeat the above test using 'async with'.
+        sem = locks.Semaphore()
+
+        global_namespace = dict(globals(), **locals())
+        local_namespace = {}
+        exec(textwrap.dedent("""
+        async def f():
+            async with sem as yielded:
+                self.assertTrue(yielded is None)
+        """), global_namespace, local_namespace)
+        yield local_namespace['f']()
+
+        # Semaphore was released and can be acquired again.
+        self.assertTrue(sem.acquire().done())
+
     @gen_test
     def test_context_manager_exception(self):
         sem = locks.Semaphore()
@@ -443,6 +465,28 @@ class LockTests(AsyncTestCase):
         yield futures
         self.assertEqual(list(range(N)), history)
 
+    @skipBefore35
+    @gen_test
+    def test_acquire_fifo_async_with(self):
+        # Repeat the above test using `async with lock:`
+        # instead of `with (yield lock.acquire()):`.
+        lock = locks.Lock()
+        self.assertTrue(lock.acquire().done())
+        N = 5
+        history = []
+
+        global_namespace = dict(globals(), **locals())
+        local_namespace = {}
+        exec(textwrap.dedent("""
+        async def f(idx):
+            async with lock:
+                history.append(idx)
+        """), global_namespace, local_namespace)
+        futures = [local_namespace['f'](i) for i in range(N)]
+        lock.release()
+        yield futures
+        self.assertEqual(list(range(N)), history)
+
     @gen_test
     def test_acquire_timeout(self):
         lock = locks.Lock()