]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- got engine events partially working, needs work on return value considerations
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jul 2010 00:35:03 +0000 (20:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jul 2010 00:35:03 +0000 (20:35 -0400)
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/event.py
lib/sqlalchemy/interfaces.py

index 264a71bef60a2577c67141b18e4a4fb66ed36038..a5f99022f93eecacfc34218c9476801037d79726 100644 (file)
@@ -1546,20 +1546,6 @@ class TwoPhaseTransaction(Transaction):
     def _do_commit(self):
         self.connection._commit_twophase_impl(self.xid, self._is_prepared)
 
-class _EngineDispatch(event.Events):
-    def append(self, fn, identifier, target):
-        if isinstance(target.Connection, Connection):
-            target.Connection = _proxy_connection_cls(target.Connection, self)
-        event.Dispatch.append(self, fn, identifier)
-
-    def exec_(self, identifier, orig, kw):
-        for fn in getattr(self, identifier):
-            r = fn(**kw)
-            if r:
-                return r
-        else:
-            return orig()
-
 class Engine(Connectable, log.Identified):
     """
     Connects a :class:`~sqlalchemy.pool.Pool` and 
@@ -1574,7 +1560,6 @@ class Engine(Connectable, log.Identified):
     _execution_options = util.frozendict()
     Connection = Connection
     
-    
     def __init__(self, pool, dialect, url, 
                         logging_name=None, echo=None, proxy=None,
                         execution_options=None
@@ -1592,41 +1577,50 @@ class Engine(Connectable, log.Identified):
         if execution_options:
             self.update_execution_options(**execution_options)
 
-    class events(_EngineDispatch):
-        def execute(self, conn, execute, clauseelement, *multiparams, **params):
+    class events(event.Events):
+        @classmethod
+        def listen(cls, target, fn, identifier):
+            if issubclass(target.Connection, Connection):
+                target.Connection = _proxy_connection_cls(
+                                            Connection, 
+                                            target.events)
+            event.Events.listen(target, fn, identifier)
+            
+        def on_execute(self, conn, execute, clauseelement, *multiparams, **params):
             """Intercept high level execute() events."""
 
-        def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+        def on_cursor_execute(self, conn, execute, cursor, statement, 
+                            parameters, context, executemany):
             """Intercept low-level cursor execute() events."""
 
-        def begin(self, conn, begin):
+        def on_begin(self, conn, begin):
             """Intercept begin() events."""
 
-        def rollback(self, conn, rollback):
+        def on_rollback(self, conn, rollback):
             """Intercept rollback() events."""
 
-        def commit(self, conn, commit):
+        def on_commit(self, conn, commit):
             """Intercept commit() events."""
 
-        def savepoint(self, conn, savepoint, name=None):
+        def on_savepoint(self, conn, savepoint, name=None):
             """Intercept savepoint() events."""
 
-        def rollback_savepoint(self, conn, rollback_savepoint, name, context):
+        def on_rollback_savepoint(self, conn, rollback_savepoint, name, context):
             """Intercept rollback_savepoint() events."""
 
-        def release_savepoint(self, conn, release_savepoint, name, context):
+        def on_release_savepoint(self, conn, release_savepoint, name, context):
             """Intercept release_savepoint() events."""
 
-        def begin_twophase(self, conn, begin_twophase, xid):
+        def on_begin_twophase(self, conn, begin_twophase, xid):
             """Intercept begin_twophase() events."""
 
-        def prepare_twophase(self, conn, prepare_twophase, xid):
+        def on_prepare_twophase(self, conn, prepare_twophase, xid):
             """Intercept prepare_twophase() events."""
 
-        def rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
+        def on_rollback_twophase(self, conn, rollback_twophase, xid, is_prepared):
             """Intercept rollback_twophase() events."""
 
-        def commit_twophase(self, conn, commit_twophase, xid, is_prepared):
+        def on_commit_twophase(self, conn, commit_twophase, xid, is_prepared):
             """Intercept commit_twophase() events."""
     events = event.dispatcher(events)
     
@@ -1847,106 +1841,106 @@ class Engine(Connectable, log.Identified):
         return self.pool.unique_connection()
 
 def _proxy_connection_cls(cls, dispatch):
+    def _exec_recursive(conn, fns, orig):
+        if not fns:
+            return orig
+        def go(*arg, **kw):
+            nested = _exec_recursive(conn, fns[1:], orig)
+            ret = fns[0](conn, nested, *arg, **kw)
+            # TODO: need to get consistent way to check 
+            # for "they called the fn, they didn't", or otherwise
+            # make some decision here how this is to work
+            #if ret is None:
+            #    return nested(*arg, **kw)
+            #else:
+            return ret
+        return go
+
     class ProxyConnection(cls):
-        def _exec_recursive(self, fns, orig):
-            if not fns:
-                return orig
-            def go(*arg, **kw):
-                nested = self._exec_recursive(fns[1:], orig)
-                ret = fns[0](self, nested, *arg, **kw)
-                if ret is None:
-                    return nested(*arg, **kw)
-                else:
-                    return ret
-            return go
-
-        def _exec_recursive_minus_self(self, fns, orig):
-            if not fns:
-                return orig
-            def go(*arg, **kw):
-                nested = self._exec_recursive(fns[1:], orig)
-                ret = fns[0](nested, *arg, **kw)
-                if ret is None:
-                    return nested(*arg, **kw)
-                else:
-                    return ret
-            return go
-            
         def execute(self, clauseelement, *multiparams, **params):
-
-            orig = super(ProxyConnection, self).execute
-            
-            g = self._exec_recursive(
-                                dispatch.on_execute, 
-                                orig) 
+            g = _exec_recursive(self, dispatch.on_execute, 
+                                super(ProxyConnection, self).execute) 
             return g(clauseelement, *multiparams, **params)
             
-            
         def _execute_clauseelement(self, clauseelement, multiparams=None, params=None):
             return self.execute(clauseelement, *(multiparams or []), **(params or {}))
 
+        # TODO : this is all wrong, cursor_execute() and 
+        # cursor_executemany() don't have a return value, need to find some 
+        # other way to check for executed on these
+        
         def _cursor_execute(self, cursor, statement, 
                                     parameters, context=None):
-            orig = super(ProxyConnection, self)._cursor_execute
-            g = self._exec_recursive_minus_self(
-                dispatch.on_cursor_execute,
-                orig
-            )
-            return g(cursor, statement, parameters, context=None)
-        
-        # these are all TODO
-        def _cursor_executemany(self, cursor, statement, 
-                                    parameters, context=None):
-            return proxy.cursor_execute(
-                            super(ProxyConnection, self)._cursor_executemany, 
-                            cursor, statement, parameters, context, True)
-
+            g = _exec_recursive(self, dispatch.on_cursor_execute,
+                    self._cursor_exec)
+            return g(cursor, statement, parameters, context, False)
+        
+        def _cursor_executemany(self, cursor, statement, parameters,
+                                                context=None, ):
+            g = _exec_recursive(self, dispatch.on_cursor_execute,
+                    self._cursor_exec)
+            return g(cursor, statement, parameters, context, True)
+        
+        def _cursor_exec(self, cursor, statement, parameters, context,
+                                        executemany):
+            if executemany:
+                return super(ProxyConnection,
+                             self)._cursor_executemany(cursor,
+                        statement, parameters, context)
+            else:
+                return super(ProxyConnection,
+                             self)._cursor_execute(cursor, statement,
+                        parameters, context)
+                
         def _begin_impl(self):
-            return proxy.begin(self, super(ProxyConnection, self)._begin_impl)
+            g = _exec_recursive(self, dispatch.on_begin,
+                    super(ProxyConnection, self)._begin_impl)
+            return g()
             
         def _rollback_impl(self):
-            return proxy.rollback(self, 
-                                super(ProxyConnection, self)._rollback_impl)
+            g = _exec_recursive(self, dispatch.on_rollback,
+                    super(ProxyConnection, self)._rollback_impl)
+            return g()
 
         def _commit_impl(self):
-            return proxy.commit(self, 
-                                super(ProxyConnection, self)._commit_impl)
+            g = _exec_recursive(self, dispatch.on_commit,
+                    super(ProxyConnection, self)._commit_impl)
+            return g()
 
         def _savepoint_impl(self, name=None):
-            return proxy.savepoint(self, 
-                                super(ProxyConnection, self)._savepoint_impl,
-                                name=name)
+            g = _exec_recursive(self, dispatch.on_savepoint,
+                    super(ProxyConnection, self)._savepoint_impl)
+            return g(name=name)
 
         def _rollback_to_savepoint_impl(self, name, context):
-            return proxy.rollback_savepoint(self, 
-                        super(ProxyConnection,
-                                self)._rollback_to_savepoint_impl, 
-                                name, context)
+            g = _exec_recursive(self, dispatch.on_rollback_savepoint,
+                super(ProxyConnection, self)._rollback_to_savepoint_impl)
+            return g(name, context)
             
         def _release_savepoint_impl(self, name, context):
-            return proxy.release_savepoint(self
-                        super(ProxyConnection, self)._release_savepoint_impl
-                        name, context)
-
+            g = _exec_recursive(self, dispatch.on_release_savepoint
+                        super(ProxyConnection, self)._release_savepoint_impl)
+            return g(name, context)
+            
         def _begin_twophase_impl(self, xid):
-            return proxy.begin_twophase(self, 
-                        super(ProxyConnection, self)._begin_twophase_impl,
-                        xid)
+            g = _exec_recursive(self, dispatch.on_begin_twophase,
+                        super(ProxyConnection, self)._begin_twophase_impl)
+            return g(xid)
 
         def _prepare_twophase_impl(self, xid):
-            return proxy.prepare_twophase(self, 
-                        super(ProxyConnection, self)._prepare_twophase_impl
-                        xid)
+            g = _exec_recursive(self, dispatch.on_prepare_twophase,
+                        super(ProxyConnection, self)._prepare_twophase_impl)
+            return g(xid)
 
         def _rollback_twophase_impl(self, xid, is_prepared):
-            return proxy.rollback_twophase(self, 
-                        super(ProxyConnection, self)._rollback_twophase_impl
-                        xid, is_prepared)
+            g = _exec_recursive(self, dispatch.on_rollback_twophase,
+                        super(ProxyConnection, self)._rollback_twophase_impl)
+            return g(xid, is_prepared)
 
         def _commit_twophase_impl(self, xid, is_prepared):
-            return proxy.commit_twophase(self, 
-                        super(ProxyConnection, self)._commit_twophase_impl
-                        xid, is_prepared)
+            g = _exec_recursive(self, dispatch.on_commit_twophase,
+                        super(ProxyConnection, self)._commit_twophase_impl)
+            return g(xid, is_prepared)
 
     return ProxyConnection
 
index 5fcda0a6535558b4a87515312d25a91bd813719d..375023e2830a8dc8d07c1ce60e2d7513161b37cc 100644 (file)
@@ -13,7 +13,7 @@ from sqlalchemy import util
 def listen(fn, identifier, target, *args):
     """Listen for events, passing to fn."""
     
-    getattr(target.events, identifier).append(fn, target)
+    target.events.listen(target, fn, identifier)
 
 NO_RESULT = util.symbol('no_result')
 
@@ -30,6 +30,10 @@ class Events(object):
     def __init__(self, parent_cls):
         self.parent_cls = parent_cls
     
+    @classmethod
+    def listen(cls, target, fn, identifier):
+        getattr(target.events, identifier).append(fn, target)
+        
 
 class _ExecEvent(object):
     def exec_and_clear(self, *args, **kw):
index c7f3a1109ebcd9caf7480e8e2f28323641f7a712..4eaf4d4ad3e4ca93ece9c84a3591931f38f56d9b 100644 (file)
@@ -174,7 +174,25 @@ class ConnectionProxy(object):
     
     @classmethod
     def _adapt_listener(cls, self, listener):
-        pass
+        event.listen(listener.execute, 'on_execute', self)
+        def _adapt_cursor_execute(conn, execute, cursor, statement, 
+                                    parameters, context, executemany):
+            def _re_execute(cursor, statement, parameters, context):
+                return execute(cursor, statement, parameters, context, executemany)
+            return listener.cursor_execute(_re_execute, cursor, statement, 
+                                        parameters, context, executemany)
+        event.listen(_adapt_cursor_execute, 'on_cursor_execute', self)
+        event.listen(listener.begin, 'on_begin', self)
+        event.listen(listener.rollback, 'on_rollback', self)
+        event.listen(listener.commit, 'on_commit', self)
+        event.listen(listener.savepoint, 'on_savepoint', self)
+        event.listen(listener.rollback_savepoint, 'on_rollback_savepoint', self)
+        event.listen(listener.release_savepoint, 'on_release_savepoint', self)
+        event.listen(listener.begin_twophase, 'on_begin_twophase', self)
+        event.listen(listener.prepare_twophase, 'on_prepare_twophase', self)
+        event.listen(listener.rollback_twophase, 'on_rollback_twophase', self)
+        event.listen(listener.commit_twophase, 'on_commit_twophase', self)
+        
         
     def execute(self, conn, execute, clauseelement, *multiparams, **params):
         """Intercept high level execute() events."""