]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved out to on_before_execute, on_after_execute. not much option here,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 29 Aug 2010 15:22:46 +0000 (11:22 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 29 Aug 2010 15:22:46 +0000 (11:22 -0400)
need both forms, the wrapping thing is just silly
- fixed the listen() to not re-wrap continuously.

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/event.py
lib/sqlalchemy/interfaces.py
lib/sqlalchemy/test/assertsql.py
lib/sqlalchemy/test/engines.py
test/engine/test_execute.py

index 1a7b2faaf54874d522ca42bd6df96d2ce055c39f..70ad0191438261d35831b45097394bfae92ab983 100644 (file)
@@ -1551,16 +1551,23 @@ class EngineEvents(event.Events):
     
     @classmethod
     def listen(cls, fn, identifier, target):
-        if issubclass(target.Connection, Connection):
-            target.Connection = _proxy_connection_cls(
+        if target.Connection is Connection:
+            target.Connection = _listener_connection_cls(
                                         Connection, 
                                         target.dispatch)
         event.Events.listen(fn, identifier, target)
 
-    def on_execute(self, conn, clauseelement, *multiparams, **params):
+    def on_before_execute(self, conn, clauseelement, multiparams, params):
+        """Intercept high level execute() events."""
+
+    def on_after_execute(self, conn, clauseelement, multiparams, params, result):
         """Intercept high level execute() events."""
         
-    def on_cursor_execute(self, conn, cursor, statement, 
+    def on_before_cursor_execute(self, conn, cursor, statement, 
+                        parameters, context, executemany):
+        """Intercept low-level cursor execute() events."""
+
+    def on_after_cursor_execute(self, conn, cursor, statement, 
                         parameters, context, executemany):
         """Intercept low-level cursor execute() events."""
 
@@ -1845,100 +1852,126 @@ class Engine(Connectable, log.Identified):
 
         return self.pool.unique_connection()
 
-def _proxy_connection_cls(cls, dispatch):
-    class ProxyConnection(cls):
+def _listener_connection_cls(cls, dispatch):
+    """Produce a wrapper for :class:`.Connection` which will apply event 
+    dispatch to each method.
+    
+    :class:`.Connection` does not provide event dispatch built in so that
+    method call overhead is avoided in the absense of any listeners.
+    
+    """
+    class EventListenerConnection(cls):
         def execute(self, clauseelement, *multiparams, **params):
-            for fn in dispatch.on_execute:
-                result = fn(self, clauseelement, *multiparams, **params)
-                if result:
-                    clauseelement, multiparams, params = result
+            if dispatch.on_before_execute:
+                for fn in dispatch.on_before_execute:
+                    result = fn(self, clauseelement, multiparams, params)
+                    if result:
+                        clauseelement, multiparams, params = result
+            
+            ret = super(EventListenerConnection, self).execute(clauseelement, *multiparams, **params)
+
+            if dispatch.on_after_execute:
+                for fn in dispatch.on_after_execute:
+                    fn(self, clauseelement, multiparams, params, ret)
             
-            return super(ProxyConnection, self).execute(clauseelement, *multiparams, **params)
+            return ret
             
         def _execute_clauseelement(self, clauseelement, multiparams=None, params=None):
             return self.execute(clauseelement, *(multiparams or []), **(params or {}))
 
         def _cursor_execute(self, cursor, statement, 
                                     parameters, context=None):
-            for fn in dispatch.on_cursor_execute:
-                result = fn(self, cursor, statement, parameters, context, False)
-                if result:
-                    statement, parameters = result
+            if dispatch.on_before_cursor_execute:
+                for fn in dispatch.on_before_cursor_execute:
+                    result = fn(self, cursor, statement, parameters, context, False)
+                    if result:
+                        statement, parameters = result
             
-            return super(ProxyConnection, self).\
+            ret = super(EventListenerConnection, self).\
                         _cursor_execute(cursor, statement, parameters, context)
 
+            if dispatch.on_after_cursor_execute:
+                for fn in dispatch.on_after_cursor_execute:
+                    fn(self, cursor, statement, parameters, context, False)
+            
+            return ret
+            
         def _cursor_executemany(self, cursor, statement, 
                                     parameters, context=None):
-            for fn in dispatch.on_cursor_execute:
+            for fn in dispatch.on_before_cursor_execute:
                 result = fn(self, cursor, statement, parameters, context, True)
                 if result:
                     statement, parameters = result
 
-            return super(ProxyConnection, self).\
+            ret = super(EventListenerConnection, self).\
                         _cursor_executemany(cursor, statement, parameters, context)
-                
+
+            for fn in dispatch.on_after_cursor_execute:
+                fn(self, cursor, statement, parameters, context, True)
+            
+            return ret
+            
         def _begin_impl(self):
             for fn in dispatch.on_begin:
                 fn(self)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _begin_impl()
             
         def _rollback_impl(self):
             for fn in dispatch.on_rollback:
                 fn(self)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _rollback_impl()
 
         def _commit_impl(self):
             for fn in dispatch.on_commit:
                 fn(self)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _commit_impl()
 
         def _savepoint_impl(self, name=None):
             for fn in dispatch.on_savepoint:
                 fn(self, name)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _savepoint_impl(name=name)
                 
         def _rollback_to_savepoint_impl(self, name, context):
             for fn in dispatch.on_rollback_to_savepoint:
                 fn(self, name, context)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _rollback_to_savepoint_impl(name, context)
             
         def _release_savepoint_impl(self, name, context):
             for fn in dispatch.on_release_savepoint:
                 fn(self, name, context)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _release_savepoint_impl(name, context)
             
         def _begin_twophase_impl(self, xid):
             for fn in dispatch.on_begin_twophase:
                 fn(self, xid)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _begin_twophase_impl(xid)
 
         def _prepare_twophase_impl(self, xid):
             for fn in dispatch.on_prepare_twophase:
                 fn(self, xid)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _prepare_twophase_impl(xid)
 
         def _rollback_twophase_impl(self, xid, is_prepared):
             for fn in dispatch.on_rollback_twophase:
                 fn(self, xid)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _rollback_twophase_impl(xid)
 
         def _commit_twophase_impl(self, xid, is_prepared):
             for fn in dispatch.on_commit_twophase:
                 fn(self, xid)
-            return super(ProxyConnection, self).\
+            return super(EventListenerConnection, self).\
                         _commit_twophase_impl(xid)
 
-    return ProxyConnection
+    return EventListenerConnection
 
 # This reconstructor is necessary so that pickles with the C extension or
 # without use the same Binary format.
index b6e687b7c97601d75b502f8948598075ce0f0a6c..c982afd6365dcae760efadb22fb35fa18cc249d5 100644 (file)
@@ -30,8 +30,8 @@ class TLConnection(base.Connection):
 class TLEvents(base.EngineEvents):
     @classmethod
     def listen(cls, fn, identifier, target):
-        if issubclass(target.TLConnection, TLConnection):
-            target.TLConnection = base._proxy_connection_cls(
+        if target.TLConnection is TLConnection:
+            target.TLConnection = base._listener_connection_cls(
                                         TLConnection, 
                                         target.dispatch)
         base.EngineEvents.listen(fn, identifier, target)
index 28ed7f56320e90f444fe824da5287bfcf41d1691..5448503b268dbf9bb8997c5816edc72c1a0e0f9a 100644 (file)
@@ -143,26 +143,11 @@ class _ListenerCollection(object):
             self(*args, **kw)
             self._exec_once = True
     
-    def exec_until_return(self, *args, **kw):
-        """Execute listeners for this event until
-        one returns a non-None value.
-
-        Returns the value, or None.
-        """
-
-        if self:
-            for fn in self:
-                r = fn(*args, **kw)
-                if r is not None:
-                    return r
-        return None
-
     def __call__(self, *args, **kw):
         """Execute this event."""
 
-        if self:
-            for fn in self:
-                fn(*args, **kw)
+        for fn in self:
+            fn(*args, **kw)
     
     # I'm not entirely thrilled about the overhead here,
     # but this allows class-level listeners to be added
index 1cceff0b4f79d8f7367ba01339d7f2c4b9d81833..2c16935ce47cec6eef9f9d8dad31add1df258954 100644 (file)
@@ -174,41 +174,69 @@ class ConnectionProxy(object):
     
     @classmethod
     def _adapt_listener(cls, self, listener):
-        
-        def adapt_execute(conn, clauseelement, *multiparams, **params):
+
+        def adapt_execute(conn, clauseelement, multiparams, params):
+
             def execute_wrapper(clauseelement, *multiparams, **params):
                 return clauseelement, multiparams, params
-            return listener.execute(conn, execute_wrapper, clauseelement, *multiparams, **params)
-            
-        event.listen(adapt_execute, 'on_execute', self)
+
+            return listener.execute(conn, execute_wrapper,
+                                    clauseelement, *multiparams,
+                                    **params)
+
+        event.listen(adapt_execute, 'on_before_execute', self)
 
         def adapt_cursor_execute(conn, cursor, statement, 
-                                    parameters, context, executemany):
-            def execute_wrapper(cursor, statement, parameters, context):
+                                parameters,context, executemany, ):
+
+            def execute_wrapper(
+                cursor,
+                statement,
+                parameters,
+                context,
+                ):
                 return statement, parameters
-            return listener.cursor_execute(execute_wrapper, cursor, statement, 
-                                        parameters, context, executemany)
-                                        
-        event.listen(adapt_cursor_execute, 'on_cursor_execute', self)
+
+            return listener.cursor_execute(
+                execute_wrapper,
+                cursor,
+                statement,
+                parameters,
+                context,
+                executemany,
+                )
+
+        event.listen(adapt_cursor_execute, 'on_before_cursor_execute',
+                     self)
 
         def do_nothing_callback(*arg, **kw):
             pass
-        
+
         def adapt_listener(fn):
+
             def go(conn, *arg, **kw):
                 fn(conn, do_nothing_callback, *arg, **kw)
+
             return util.update_wrapper(go, fn)
-            
+
         event.listen(adapt_listener(listener.begin), 'on_begin', self)
-        event.listen(adapt_listener(listener.rollback), 'on_rollback', self)
+        event.listen(adapt_listener(listener.rollback), 'on_rollback',
+                     self)
         event.listen(adapt_listener(listener.commit), 'on_commit', self)
-        event.listen(adapt_listener(listener.savepoint), 'on_savepoint', self)
-        event.listen(adapt_listener(listener.rollback_savepoint), 'on_rollback_savepoint', self)
-        event.listen(adapt_listener(listener.release_savepoint), 'on_release_savepoint', self)
-        event.listen(adapt_listener(listener.begin_twophase), 'on_begin_twophase', self)
-        event.listen(adapt_listener(listener.prepare_twophase), 'on_prepare_twophase', self)
-        event.listen(adapt_listener(listener.rollback_twophase), 'on_rollback_twophase', self)
-        event.listen(adapt_listener(listener.commit_twophase), 'on_commit_twophase', self)
+        event.listen(adapt_listener(listener.savepoint), 'on_savepoint'
+                     , self)
+        event.listen(adapt_listener(listener.rollback_savepoint),
+                     'on_rollback_savepoint', self)
+        event.listen(adapt_listener(listener.release_savepoint),
+                     'on_release_savepoint', self)
+        event.listen(adapt_listener(listener.begin_twophase),
+                     'on_begin_twophase', self)
+        event.listen(adapt_listener(listener.prepare_twophase),
+                     'on_prepare_twophase', self)
+        event.listen(adapt_listener(listener.rollback_twophase),
+                     'on_rollback_twophase', self)
+        event.listen(adapt_listener(listener.commit_twophase),
+                     'on_commit_twophase', self)
         
         
     def execute(self, conn, execute, clauseelement, *multiparams, **params):
index 11ad20e7751dae0fbe08a4800300f43de725bb8e..dee63a876da4c9bd26c4cdc52d9cd83081e4a60c 100644 (file)
@@ -273,10 +273,7 @@ class SQLAssert(object):
     def clear_rules(self):
         del self.rules
         
-    def execute(self, conn, clauseelement, *multiparams, **params):
-        # TODO: this doesn't work.   we need to execute before so that we know 
-        # what's happened with the parameters.
-        
+    def execute(self, conn, clauseelement, multiparams, params, result):
         if self.rules is not None:
             if not self.rules:
                 assert False, "All rules have been exhausted, but further statements remain"
@@ -287,7 +284,6 @@ class SQLAssert(object):
             
         
     def cursor_execute(self, conn, cursor, statement, parameters, context, executemany):
-        print "RECEIVE !", statement, parameters
         if self.rules:
             rule = self.rules[0]
             rule.process_cursor_execute(statement, parameters, context, executemany)
index 779f872646a8c289815422f2591963df20abc636..8b930175fe3a08af561c62ab0a811edae4f66fb1 100644 (file)
@@ -135,8 +135,8 @@ def testing_engine(url=None, options=None):
     options = options or config.db_opts
 
     engine = create_engine(url, **options)
-    event.listen(asserter.execute, 'on_execute', engine)
-    event.listen(asserter.cursor_execute, 'on_cursor_execute', engine)
+    event.listen(asserter.execute, 'on_after_execute', engine)
+    event.listen(asserter.cursor_execute, 'on_after_cursor_execute', engine)
     event.listen(testing_reaper.checkout, 'on_checkout', engine.pool)
     
     # may want to call this, results
index 2c6caf87f9f0063c91277c10d1fe8be715b06501..d85279981def457d8f9ecd07603a9cfb847a8efe 100644 (file)
@@ -305,13 +305,13 @@ class EngineEventsTest(TestBase):
                     break
 
     @testing.fails_on('firebird', 'Data type unknown')
-    def test_execute_events_raw(self):
+    def test_execute_events(self):
 
         stmts = []
         cursor_stmts = []
 
-        def execute(conn, clauseelement, *multiparams,
-                                                    **params ):
+        def execute(conn, clauseelement, multiparams,
+                                                    params ):
             stmts.append((str(clauseelement), params, multiparams))
 
         def cursor_execute(conn, cursor, statement, parameters, 
@@ -324,8 +324,8 @@ class EngineEventsTest(TestBase):
             engines.testing_engine(options=dict(implicit_returning=False,
                                    strategy='threadlocal'))
             ]:
-            event.listen(execute, 'on_execute', engine)
-            event.listen(cursor_execute, 'on_cursor_execute', engine)
+            event.listen(execute, 'on_before_execute', engine)
+            event.listen(cursor_execute, 'on_before_cursor_execute', engine)
             
             m = MetaData(engine)
             t1 = Table('t1', m, 
@@ -375,78 +375,7 @@ class EngineEventsTest(TestBase):
             self._assert_stmts(compiled, stmts)
             self._assert_stmts(cursor, cursor_stmts)
 
-    @testing.fails_on('firebird', 'Data type unknown')
-    def _broken_test_execute_events_generic(self):
-
-        stmts = []
-        cursor_stmts = []
-
-        def listen(event_name, args):
-            if event_name == 'on_execute':
-                clauseelement, params, multiparams = \
-                    args['clauseelement'], args['params'], args['multiparams']
-                stmts.append((str(clauseelement), params, multiparams))
-            elif event_name == 'on_cursor_execute':
-                statement, parameters = args['statement'], args['parameters']
-                cursor_stmts.append((str(statement), parameters, None))
-
-        for engine in [
-            engines.testing_engine(options=dict(implicit_returning=False)), 
-            engines.testing_engine(options=dict(implicit_returning=False,
-                                   strategy='threadlocal'))
-            ]:
-            event.listen(listen, 'on_execute', engine)
-            event.listen(listen, 'on_cursor_execute', engine)
-
-            m = MetaData(engine)
-            t1 = Table('t1', m, 
-                Column('c1', Integer, primary_key=True), 
-                Column('c2', String(50), default=func.lower('Foo'),
-                                            primary_key=True)
-            )
-            m.create_all()
-            try:
-                t1.insert().execute(c1=5, c2='some data')
-                t1.insert().execute(c1=6)
-                eq_(engine.execute('select * from t1').fetchall(), [(5,
-                    'some data'), (6, 'foo')])
-            finally:
-                m.drop_all()
-            engine.dispose()
-            compiled = [('CREATE TABLE t1', {}, None),
-                        ('INSERT INTO t1 (c1, c2)', {'c2': 'some data',
-                        'c1': 5}, None), ('INSERT INTO t1 (c1, c2)',
-                        {'c1': 6}, None), ('select * from t1', {},
-                        None), ('DROP TABLE t1', {}, None)]
-            if not testing.against('oracle+zxjdbc'):  # or engine.dialect.pr
-                                                      # eexecute_pk_sequence
-                                                      # s:
-                cursor = [
-                    ('CREATE TABLE t1', {}, ()),
-                    ('INSERT INTO t1 (c1, c2)', {'c2': 'some data', 'c1'
-                     : 5}, (5, 'some data')),
-                    ('SELECT lower', {'lower_2': 'Foo'}, ('Foo', )),
-                    ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6},
-                     (6, 'foo')),
-                    ('select * from t1', {}, ()),
-                    ('DROP TABLE t1', {}, ()),
-                    ]
-            else:
-                insert2_params = 6, 'Foo'
-                if testing.against('oracle+zxjdbc'):
-                    insert2_params += (ReturningParam(12), )
-                cursor = [('CREATE TABLE t1', {}, ()),
-                          ('INSERT INTO t1 (c1, c2)', {'c2': 'some data'
-                          , 'c1': 5}, (5, 'some data')),
-                          ('INSERT INTO t1 (c1, c2)', {'c1': 6,
-                          'lower_2': 'Foo'}, insert2_params),
-                          ('select * from t1', {}, ()), ('DROP TABLE t1'
-                          , {}, ())]  # bind param name 'lower_2' might
-                                      # be incorrect
-            self._assert_stmts(compiled, stmts)
-            self._assert_stmts(cursor, cursor_stmts)
-
-    def test_options_raw(self):
+    def test_options(self):
         track = []
         def on_execute(conn, *args, **kw):
             track.append('execute')
@@ -455,8 +384,8 @@ class EngineEventsTest(TestBase):
             track.append('cursor_execute')
             
         engine = engines.testing_engine()
-        event.listen(on_execute, 'on_execute', engine)
-        event.listen(on_cursor_execute, 'on_cursor_execute', engine)
+        event.listen(on_execute, 'on_before_execute', engine)
+        event.listen(on_cursor_execute, 'on_before_cursor_execute', engine)
         conn = engine.connect()
         c2 = conn.execution_options(foo='bar')
         eq_(c2._execution_options, {'foo':'bar'})
@@ -466,7 +395,7 @@ class EngineEventsTest(TestBase):
         eq_(track, ['execute', 'cursor_execute'])
 
 
-    def test_transactional_raw(self):
+    def test_transactional(self):
         track = []
         def tracker(name):
             def go(conn, *args, **kw):
@@ -474,8 +403,8 @@ class EngineEventsTest(TestBase):
             return go
             
         engine = engines.testing_engine()
-        event.listen(tracker('execute'), 'on_execute', engine)
-        event.listen(tracker('cursor_execute'), 'on_cursor_execute', engine)
+        event.listen(tracker('execute'), 'on_before_execute', engine)
+        event.listen(tracker('cursor_execute'), 'on_before_cursor_execute', engine)
         event.listen(tracker('begin'), 'on_begin', engine)
         event.listen(tracker('commit'), 'on_commit', engine)
         event.listen(tracker('rollback'), 'on_rollback', engine)
@@ -495,7 +424,7 @@ class EngineEventsTest(TestBase):
 
     @testing.requires.savepoints
     @testing.requires.two_phase_transactions
-    def test_transactional_advanced_raw(self):
+    def test_transactional_advanced(self):
         track = []
         def tracker(name):
             def go(conn, exec_, *args, **kw):