]> git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Simpler code for Semaphore.acquire() as a context manager.
authorA. Jesse Jiryu Davis <jesse@mongodb.com>
Wed, 25 Feb 2015 20:52:55 +0000 (15:52 -0500)
committerA. Jesse Jiryu Davis <jesse@mongodb.com>
Wed, 25 Feb 2015 20:52:55 +0000 (15:52 -0500)
tornado/locks.py

index 73a06c85bc6fb25df7b37c48a9e11d2a5a05786d..f6852ea08f861410118c4340091b08314ea6b539 100644 (file)
@@ -17,7 +17,6 @@ from __future__ import absolute_import, division, print_function, with_statement
 __all__ = ['Condition', 'Event', 'Semaphore']
 
 import collections
-import contextlib
 
 from tornado import gen, ioloop
 from tornado.concurrent import Future
@@ -126,35 +125,22 @@ class Event(object):
             return gen.with_timeout(timeout, self._future)
 
 
-class _ContextManagerFuture(Future):
-    """A Future that can be used with the "with" statement.
+class _ReleasingContextManager(object):
+    """Releases a Lock or Semaphore at the end of a "with" statement.
 
-    When a coroutine yields this Future, the return value is a context manager
-    that can be used like:
-
-        with (yield future):
+        with (yield semaphore.acquire()):
             pass
 
-    At the end of the block, the Future's exit callback is run. Used for
-    Lock.acquire and Semaphore.acquire.
+        # Now semaphore.release() has been called.
     """
-    def __init__(self, wrapped, exit_callback):
-        super(_ContextManagerFuture, self).__init__()
-        gen.chain_future(wrapped, self)
-        self.exit_callback = exit_callback
+    def __init__(self, obj):
+        self._obj = obj
 
-    def result(self, timeout=None):
-        if self.exception():
-            raise self.exception()
+    def __enter__(self):
+        pass
 
-        # Otherwise return a context manager that cleans up after the block.
-        @contextlib.contextmanager
-        def f():
-            try:
-                yield
-            finally:
-                self.exit_callback()
-        return f()
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._obj.release()
 
 
 class Semaphore(object):
@@ -212,7 +198,14 @@ class Semaphore(object):
         for waiter in self._waiters:
             if not waiter.done():
                 self._value -= 1
-                waiter.set_result(None)
+
+                # If the waiter is a coroutine paused at
+                #
+                #     with (yield semaphore.acquire()):
+                #
+                # then the context manager's __exit__ calls release() at the end
+                # of the "with" block.
+                waiter.set_result(_ReleasingContextManager(self))
                 break
 
     def acquire(self, timeout=None):
@@ -223,7 +216,8 @@ class Semaphore(object):
         """
         if self._value > 0:
             self._value -= 1
-            future = gen._null_future
+            future = Future()
+            future.set_result(_ReleasingContextManager(self))
         else:
             waiter = Future()
             self._waiters.append(waiter)
@@ -235,7 +229,7 @@ class Semaphore(object):
                 gen.chain_future(future, waiter)
             else:
                 future = waiter
-        return _ContextManagerFuture(future, self.release)
+        return future
 
     def __enter__(self):
         raise RuntimeError(