]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
shoulda listened harder in APL class
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Jul 2010 22:52:47 +0000 (18:52 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Jul 2010 22:52:47 +0000 (18:52 -0400)
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/event.py

index ae7df83f6064f811988cf61689eb9f8c655d46b5..264a71bef60a2577c67141b18e4a4fb66ed36038 100644 (file)
@@ -1573,7 +1573,6 @@ class Engine(Connectable, log.Identified):
 
     _execution_options = util.frozendict()
     Connection = Connection
-    _dispatch = event.dispatcher(_EngineDispatch)
     
     
     def __init__(self, pool, dialect, url, 
@@ -1592,6 +1591,44 @@ class Engine(Connectable, log.Identified):
             interfaces.ConnectionProxy._adapt_listener(self, proxy)
         if execution_options:
             self.update_execution_options(**execution_options)
+
+    class events(_EngineDispatch):
+        def execute(self, conn, execute, clauseelement, *multiparams, **params):
+            """Intercept high level execute() events."""
+
+        def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+            """Intercept low-level cursor execute() events."""
+
+        def begin(self, conn, begin):
+            """Intercept begin() events."""
+
+        def rollback(self, conn, rollback):
+            """Intercept rollback() events."""
+
+        def commit(self, conn, commit):
+            """Intercept commit() events."""
+
+        def savepoint(self, conn, savepoint, name=None):
+            """Intercept savepoint() events."""
+
+        def rollback_savepoint(self, conn, rollback_savepoint, name, context):
+            """Intercept rollback_savepoint() events."""
+
+        def release_savepoint(self, conn, release_savepoint, name, context):
+            """Intercept release_savepoint() events."""
+
+        def begin_twophase(self, conn, begin_twophase, xid):
+            """Intercept begin_twophase() events."""
+
+        def prepare_twophase(self, conn, prepare_twophase, xid):
+            """Intercept prepare_twophase() events."""
+
+        def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
+            """Intercept rollback_twophase() events."""
+
+        def commit_twophase(self, conn, commit_twophase, xid, is_prepared):
+            """Intercept commit_twophase() events."""
+    events = event.dispatcher(events)
     
     def update_execution_options(self, **opt):
         """update the execution_options dictionary of this :class:`Engine`.
@@ -1811,50 +1848,51 @@ class Engine(Connectable, log.Identified):
 
 def _proxy_connection_cls(cls, dispatch):
     class ProxyConnection(cls):
-        def execute(self, object, *multiparams, **params):
-            if not dispatch.on_execute:
-                return super(ProxyConnection, self).execute(object, *multiparams, **params)
-            else:
-                orig = super(ProxyConnection, self).execute
-                return dispatch.exec_('on_execute', orig, 
-                                        conn=self, 
-                                        execute=orig, 
-                                        clauseelement=object, 
-                                        multiparams=multiparams, 
-                                        params=params
-                )
+        def _exec_recursive(self, fns, orig):
+            if not fns:
+                return orig
+            def go(*arg, **kw):
+                nested = self._exec_recursive(fns[1:], orig)
+                ret = fns[0](self, nested, *arg, **kw)
+                if ret is None:
+                    return nested(*arg, **kw)
+                else:
+                    return ret
+            return go
+
+        def _exec_recursive_minus_self(self, fns, orig):
+            if not fns:
+                return orig
+            def go(*arg, **kw):
+                nested = self._exec_recursive(fns[1:], orig)
+                ret = fns[0](nested, *arg, **kw)
+                if ret is None:
+                    return nested(*arg, **kw)
+                else:
+                    return ret
+            return go
             
-        def _execute_clauseelement(self, elem, multiparams=None, params=None):
-            if not dispatch.on_execute:
-                return super(ProxyConnection, self).\
-                        _execute_clauseelement(elem, 
-                                    multiparams=multiparams, 
-                                    params=params)
-            else:
-                orig = super(ProxyConnection, self).execute
-                return dispatch.exec_('on_execute', orig, 
-                                    conn=self, 
-                                    execute=orig, 
-                                    clauseelement=elem, 
-                                    multiparams=multiparams or [], 
-                                    params=params or {}
-                )
+        def execute(self, clauseelement, *multiparams, **params):
 
+            orig = super(ProxyConnection, self).execute
+            
+            g = self._exec_recursive(
+                                dispatch.on_execute, 
+                                orig) 
+            return g(clauseelement, *multiparams, **params)
+            
+            
+        def _execute_clauseelement(self, clauseelement, multiparams=None, params=None):
+            return self.execute(clauseelement, *(multiparams or []), **(params or {}))
 
         def _cursor_execute(self, cursor, statement, 
                                     parameters, context=None):
             orig = super(ProxyConnection, self)._cursor_execute
-            if not dispatch.on_cursor_execute:
-                return orig(cursor, statement, parameters, context=context)
-            else:
-                return dispatch.exec_('on_cursor_execute', orig, 
-                                    conn=self, 
-                                    execute=super(ProxyConnection, self).execute, 
-                                    cursor=cursor,
-                                    statement=statement,
-                                    parameters=parameters,
-                                    executemany=False,
-                                    context=context)
+            g = self._exec_recursive_minus_self(
+                dispatch.on_cursor_execute,
+                orig
+            )
+            return g(cursor, statement, parameters, context=None)
         
         # these are all TODO
         def _cursor_executemany(self, cursor, statement, 
index f844b33452aebb61adbae9b89b3b472a66362e66..5fcda0a6535558b4a87515312d25a91bd813719d 100644 (file)
@@ -33,13 +33,27 @@ class Events(object):
 
 class _ExecEvent(object):
     def exec_and_clear(self, *args, **kw):
-        """Execute the given event once, then clear all listeners."""
+        """Execute this event once, then clear all listeners."""
         
         self(*args, **kw)
         self[:] = []
+    
+    def exec_until_return(self, *args, **kw):
+        """Execute listeners for this event until
+        one returns a non-None value.
+        
+        Returns the value, or None.
+        """
+        
+        if self:
+            for fn in self:
+                r = fn(*args, **kw)
+                if r is not None:
+                    return r
+        return None
         
     def __call__(self, *args, **kw):
-        """Execute the given event."""
+        """Execute this event."""
         
         if self:
             for fn in self: