]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added @event.listens_for() decorator, given
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Apr 2011 17:29:11 +0000 (13:29 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 2 Apr 2011 17:29:11 +0000 (13:29 -0400)
target + event name, applies the decorated
function as a listener.  [ticket:2106]
- remove usage of globals from test.base.test_events

CHANGES
doc/build/core/event.rst
lib/sqlalchemy/event.py
test/base/test_events.py

diff --git a/CHANGES b/CHANGES
index 3765a6ed9eed32db3457543f3ba68dad9e46e76c..b4127beae4439da4a7fbe10b06bb5141d469a7c0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -46,6 +46,11 @@ CHANGES
     collection of Sequence objects, list
     of schema names.  [ticket:2104]
 
+-event
+  - Added @event.listens_for() decorator, given
+    target + event name, applies the decorated 
+    function as a listener.  [ticket:2106]
+
 - pool
   - AssertionPool now stores the traceback indicating
     where the currently checked out connection was
index 86cc7d968faf2eaf5fc8361692f327bee0ab689f..68d4802bc5b4884373b81c685e15570140fc4f04 100644 (file)
@@ -102,3 +102,5 @@ API Reference
 
 .. autofunction:: sqlalchemy.event.listen
 
+.. autofunction:: sqlalchemy.event.listens_for
+
index b2e5cd00f256306cb2f2889e5024186b71cfb139..4be227c51634dd8efb9a580b68b7a967d353ea99 100644 (file)
@@ -13,6 +13,21 @@ NO_RETVAL = util.symbol('NO_RETVAL')
 
 def listen(target, identifier, fn, *args, **kw):
     """Register a listener function for the given target.
+    
+    e.g.::
+    
+        from sqlalchemy import event
+        from sqlalchemy.schema import UniqueConstraint
+        
+        def unique_constraint_name(const, table):
+            const.name = "uq_%s_%s" % (
+                table.name,
+                list(const.columns)[0].name
+            )
+        event.listen(
+                UniqueConstraint, 
+                "after_parent_attach", 
+                unique_constraint_name)
 
     """
 
@@ -24,6 +39,26 @@ def listen(target, identifier, fn, *args, **kw):
     raise exc.InvalidRequestError("No such event '%s' for target '%s'" %
                                 (identifier,target))
 
+def listens_for(target, identifier, *args, **kw):
+    """Decorate a function as a listener for the given target + identifier.
+    
+    e.g.::
+    
+        from sqlalchemy import event
+        from sqlalchemy.schema import UniqueConstraint
+        
+        @event.listens_for(UniqueConstraint, "after_parent_attach")
+        def unique_constraint_name(const, table):
+            const.name = "uq_%s_%s" % (
+                table.name,
+                list(const.columns)[0].name
+            )
+    """
+    def decorate(fn):
+        listen(target, identifier, fn, *args, **kw)
+        return fn
+    return decorate
+
 def remove(target, identifier, fn):
     """Remove an event listener.
 
index 96cda7cc9051442f41d6902865c5a7d2c937ea3f..94d3dad8558aa4b86b4625a847102e9b687375a4 100644 (file)
@@ -8,8 +8,6 @@ class TestEvents(fixtures.TestBase):
     """Test class- and instance-level event registration."""
 
     def setUp(self):
-        global Target
-
         assert 'event_one' not in event._registrars
         assert 'event_two' not in event._registrars
 
@@ -20,31 +18,35 @@ class TestEvents(fixtures.TestBase):
             def event_two(self, x):
                 pass
 
+            def event_three(self, x):
+                pass
+
         class Target(object):
             dispatch = event.dispatcher(TargetEvents)
+        self.Target = Target
 
     def tearDown(self):
-        event._remove_dispatcher(Target.__dict__['dispatch'].events)
+        event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
 
     def test_register_class(self):
         def listen(x, y):
             pass
 
-        event.listen(Target, "event_one", listen)
+        event.listen(self.Target, "event_one", listen)
 
-        eq_(len(Target().dispatch.event_one), 1)
-        eq_(len(Target().dispatch.event_two), 0)
+        eq_(len(self.Target().dispatch.event_one), 1)
+        eq_(len(self.Target().dispatch.event_two), 0)
 
     def test_register_instance(self):
         def listen(x, y):
             pass
 
-        t1 = Target()
+        t1 = self.Target()
         event.listen(t1, "event_one", listen)
 
-        eq_(len(Target().dispatch.event_one), 0)
+        eq_(len(self.Target().dispatch.event_one), 0)
         eq_(len(t1.dispatch.event_one), 1)
-        eq_(len(Target().dispatch.event_two), 0)
+        eq_(len(self.Target().dispatch.event_two), 0)
         eq_(len(t1.dispatch.event_two), 0)
 
     def test_register_class_instance(self):
@@ -54,21 +56,21 @@ class TestEvents(fixtures.TestBase):
         def listen_two(x, y):
             pass
 
-        event.listen(Target, "event_one", listen_one)
+        event.listen(self.Target, "event_one", listen_one)
 
-        t1 = Target()
+        t1 = self.Target()
         event.listen(t1, "event_one", listen_two)
 
-        eq_(len(Target().dispatch.event_one), 1)
+        eq_(len(self.Target().dispatch.event_one), 1)
         eq_(len(t1.dispatch.event_one), 2)
-        eq_(len(Target().dispatch.event_two), 0)
+        eq_(len(self.Target().dispatch.event_two), 0)
         eq_(len(t1.dispatch.event_two), 0)
 
         def listen_three(x, y):
             pass
 
-        event.listen(Target, "event_one", listen_three)
-        eq_(len(Target().dispatch.event_one), 2)
+        event.listen(self.Target, "event_one", listen_three)
+        eq_(len(self.Target().dispatch.event_one), 2)
         eq_(len(t1.dispatch.event_one), 3)
 
     def test_append_vs_insert(self):
@@ -81,21 +83,44 @@ class TestEvents(fixtures.TestBase):
         def listen_three(x, y):
             pass
 
-        event.listen(Target, "event_one", listen_one)
-        event.listen(Target, "event_one", listen_two)
-        event.listen(Target, "event_one", listen_three, insert=True)
+        event.listen(self.Target, "event_one", listen_one)
+        event.listen(self.Target, "event_one", listen_two)
+        event.listen(self.Target, "event_one", listen_three, insert=True)
 
         eq_(
-            list(Target().dispatch.event_one),
+            list(self.Target().dispatch.event_one),
             [listen_three, listen_one, listen_two]
         )
 
+    def test_decorator(self):
+        @event.listens_for(self.Target, "event_one")
+        def listen_one(x, y):
+            pass
+
+        @event.listens_for(self.Target, "event_two")
+        @event.listens_for(self.Target, "event_three")
+        def listen_two(x, y):
+            pass
+
+        eq_(
+            list(self.Target().dispatch.event_one),
+            [listen_one]
+        )
+
+        eq_(
+            list(self.Target().dispatch.event_two),
+            [listen_two]
+        )
+
+        eq_(
+            list(self.Target().dispatch.event_three),
+            [listen_two]
+        )
+
 class TestAcceptTargets(fixtures.TestBase):
     """Test default target acceptance."""
 
     def setUp(self):
-        global TargetOne, TargetTwo
-
         class TargetEventsOne(event.Events):
             def event_one(self, x, y):
                 pass
@@ -109,10 +134,12 @@ class TestAcceptTargets(fixtures.TestBase):
 
         class TargetTwo(object):
             dispatch = event.dispatcher(TargetEventsTwo)
+        self.TargetOne = TargetOne
+        self.TargetTwo = TargetTwo
 
     def tearDown(self):
-        event._remove_dispatcher(TargetOne.__dict__['dispatch'].events)
-        event._remove_dispatcher(TargetTwo.__dict__['dispatch'].events)
+        event._remove_dispatcher(self.TargetOne.__dict__['dispatch'].events)
+        event._remove_dispatcher(self.TargetTwo.__dict__['dispatch'].events)
 
     def test_target_accept(self):
         """Test that events of the same name are routed to the correct
@@ -132,21 +159,21 @@ class TestAcceptTargets(fixtures.TestBase):
         def listen_four(x, y):
             pass
 
-        event.listen(TargetOne, "event_one", listen_one)
-        event.listen(TargetTwo, "event_one", listen_two)
+        event.listen(self.TargetOne, "event_one", listen_one)
+        event.listen(self.TargetTwo, "event_one", listen_two)
 
         eq_(
-            list(TargetOne().dispatch.event_one),
+            list(self.TargetOne().dispatch.event_one),
             [listen_one]
         )
 
         eq_(
-            list(TargetTwo().dispatch.event_one),
+            list(self.TargetTwo().dispatch.event_one),
             [listen_two]
         )
 
-        t1 = TargetOne()
-        t2 = TargetTwo()
+        t1 = self.TargetOne()
+        t2 = self.TargetTwo()
 
         event.listen(t1, "event_one", listen_three)
         event.listen(t2, "event_one", listen_four)
@@ -165,8 +192,6 @@ class TestCustomTargets(fixtures.TestBase):
     """Test custom target acceptance."""
 
     def setUp(self):
-        global Target
-
         class TargetEvents(event.Events):
             @classmethod
             def _accept_with(cls, target):
@@ -180,9 +205,10 @@ class TestCustomTargets(fixtures.TestBase):
 
         class Target(object):
             dispatch = event.dispatcher(TargetEvents)
+        self.Target = Target
 
     def tearDown(self):
-        event._remove_dispatcher(Target.__dict__['dispatch'].events)
+        event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
 
     def test_indirect(self):
         def listen(x, y):
@@ -191,22 +217,20 @@ class TestCustomTargets(fixtures.TestBase):
         event.listen("one", "event_one", listen)
 
         eq_(
-            list(Target().dispatch.event_one),
+            list(self.Target().dispatch.event_one),
             [listen]
         )
 
         assert_raises(
             exc.InvalidRequestError, 
             event.listen,
-            listen, "event_one", Target
+            listen, "event_one", self.Target
         )
 
 class TestListenOverride(fixtures.TestBase):
     """Test custom listen functions which change the listener function signature."""
 
     def setUp(self):
-        global Target
-
         class TargetEvents(event.Events):
             @classmethod
             def _listen(cls, target, identifier, fn, add=False):
@@ -223,9 +247,10 @@ class TestListenOverride(fixtures.TestBase):
 
         class Target(object):
             dispatch = event.dispatcher(TargetEvents)
+        self.Target = Target
 
     def tearDown(self):
-        event._remove_dispatcher(Target.__dict__['dispatch'].events)
+        event._remove_dispatcher(self.Target.__dict__['dispatch'].events)
 
     def test_listen_override(self):
         result = []
@@ -235,10 +260,10 @@ class TestListenOverride(fixtures.TestBase):
         def listen_two(x, y):
             result.append((x, y))
 
-        event.listen(Target, "event_one", listen_one, add=True)
-        event.listen(Target, "event_one", listen_two)
+        event.listen(self.Target, "event_one", listen_one, add=True)
+        event.listen(self.Target, "event_one", listen_two)
 
-        t1 = Target()
+        t1 = self.Target()
         t1.dispatch.event_one(5, 7)
         t1.dispatch.event_one(10, 5)
 
@@ -250,8 +275,6 @@ class TestListenOverride(fixtures.TestBase):
 
 class TestPropagate(fixtures.TestBase):
     def setUp(self):
-        global Target
-
         class TargetEvents(event.Events):
             def event_one(self, arg):
                 pass
@@ -261,6 +284,7 @@ class TestPropagate(fixtures.TestBase):
 
         class Target(object):
             dispatch = event.dispatcher(TargetEvents)
+        self.Target = Target
 
 
     def test_propagate(self):
@@ -271,12 +295,12 @@ class TestPropagate(fixtures.TestBase):
         def listen_two(target, arg):
             result.append((target, arg))
 
-        t1 = Target()
+        t1 = self.Target()
 
         event.listen(t1, "event_one", listen_one, propagate=True)
         event.listen(t1, "event_two", listen_two)
 
-        t2 = Target()
+        t2 = self.Target()
 
         t2.dispatch._update(t1.dispatch)