]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rename EngineEvents to ConnectionEvents
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Feb 2011 00:59:45 +0000 (19:59 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Feb 2011 00:59:45 +0000 (19:59 -0500)
- simplify connection event model to be inline inside Connection, don't use ad-hoc
subclasses (technically would leak memory for the app that keeps creating engines
and adding events)
- not doing listen-per-connection yet.  this is closer.  overall things
are much simpler now (until we put listen-per-connection in...)

doc/build/core/events.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/event.py
lib/sqlalchemy/events.py
test/engine/test_execute.py

index ffa0fe6251bdf000305df4c4ab51346c6db52edb..fe8a45de51919f4e692bad5a7d9c9421969d9101 100644 (file)
@@ -19,7 +19,7 @@ Connection Pool Events
 Connection Events
 -----------------------
 
-.. autoclass:: sqlalchemy.events.EngineEvents
+.. autoclass:: sqlalchemy.events.ConnectionEvents
     :members:
 
 Schema Events
index ae29ac40b4af60947ccfecabe97a677629c8a03c..9eb5a38ffcc4730bb3c0760dfa052be2336a0866 100644 (file)
@@ -847,6 +847,7 @@ class Connection(Connectable):
         self.__savepoint_seq = 0
         self.__branch = _branch
         self.__invalid = False
+        self._has_events = engine._has_events
         self._echo = self.engine._should_log_info()
         if _execution_options:
             self._execution_options =\
@@ -1107,6 +1108,10 @@ class Connection(Connectable):
     def _begin_impl(self):
         if self._echo:
             self.engine.logger.info("BEGIN (implicit)")
+
+        if self._has_events:
+            self.engine.dispatch.begin(self)
+
         try:
             self.engine.dialect.do_begin(self.connection)
         except Exception, e:
@@ -1114,6 +1119,9 @@ class Connection(Connectable):
             raise
 
     def _rollback_impl(self):
+        if self._has_events:
+            self.engine.dispatch.rollback(self)
+
         if not self.closed and not self.invalidated and \
                         self._connection_is_valid:
             if self._echo:
@@ -1128,6 +1136,9 @@ class Connection(Connectable):
             self.__transaction = None
 
     def _commit_impl(self):
+        if self._has_events:
+            self.engine.dispatch.commit(self)
+
         if self._echo:
             self.engine.logger.info("COMMIT")
         try:
@@ -1138,6 +1149,9 @@ class Connection(Connectable):
             raise
 
     def _savepoint_impl(self, name=None):
+        if self._has_events:
+            self.engine.dispatch.savepoint(self, name)
+
         if name is None:
             self.__savepoint_seq += 1
             name = 'sa_savepoint_%s' % self.__savepoint_seq
@@ -1146,31 +1160,49 @@ class Connection(Connectable):
             return name
 
     def _rollback_to_savepoint_impl(self, name, context):
+        if self._has_events:
+            self.engine.dispatch.rollback_savepoint(self, name, context)
+
         if self._connection_is_valid:
             self.engine.dialect.do_rollback_to_savepoint(self, name)
         self.__transaction = context
 
     def _release_savepoint_impl(self, name, context):
+        if self._has_events:
+            self.engine.dispatch.release_savepoint(self, name, context)
+
         if self._connection_is_valid:
             self.engine.dialect.do_release_savepoint(self, name)
         self.__transaction = context
 
     def _begin_twophase_impl(self, xid):
+        if self._has_events:
+            self.engine.dispatch.begin_twophase(self, xid)
+
         if self._connection_is_valid:
             self.engine.dialect.do_begin_twophase(self, xid)
 
     def _prepare_twophase_impl(self, xid):
+        if self._has_events:
+            self.engine.dispatch.prepare_twophase(self, xid)
+
         if self._connection_is_valid:
             assert isinstance(self.__transaction, TwoPhaseTransaction)
             self.engine.dialect.do_prepare_twophase(self, xid)
 
     def _rollback_twophase_impl(self, xid, is_prepared):
+        if self._has_events:
+            self.engine.dispatch.rollback_twophase(self, xid, is_prepared)
+
         if self._connection_is_valid:
             assert isinstance(self.__transaction, TwoPhaseTransaction)
             self.engine.dialect.do_rollback_twophase(self, xid, is_prepared)
         self.__transaction = None
 
     def _commit_twophase_impl(self, xid, is_prepared):
+        if self._has_events:
+            self.engine.dispatch.commit_twophase(self, xid, is_prepared)
+
         if self._connection_is_valid:
             assert isinstance(self.__transaction, TwoPhaseTransaction)
             self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
@@ -1218,7 +1250,6 @@ class Connection(Connectable):
         * a :class:`.Compiled` object
 
         """
-
         for c in type(object).__mro__:
             if c in Connection.executors:
                 return Connection.executors[c](
@@ -1272,6 +1303,11 @@ class Connection(Connectable):
     def _execute_default(self, default, multiparams, params):
         """Execute a schema.ColumnDefault object."""
 
+        if self._has_events:
+            for fn in self.engine.dispatch.before_execute:
+                default, multiparams, params = \
+                    fn(self, default, multiparams, params)
+
         try:
             try:
                 conn = self.__connection
@@ -1288,83 +1324,121 @@ class Connection(Connectable):
         ret = ctx._exec_default(default, None)
         if self.should_close_with_result:
             self.close()
+
+        if self._has_events:
+            self.engine.dispatch.after_execute(self, 
+                default, multiparams, params, ret)
+
         return ret
 
-    def _execute_ddl(self, ddl, params, multiparams):
+    def _execute_ddl(self, ddl, multiparams, params):
         """Execute a schema.DDL object."""
 
+        if self._has_events:
+            for fn in self.engine.dispatch.before_execute:
+                ddl, multiparams, params = \
+                    fn(self, ddl, multiparams, params)
+
         dialect = self.dialect
 
         compiled = ddl.compile(dialect=dialect)
-        return self._execute_context(
+        ret = self._execute_context(
             dialect,
             dialect.execution_ctx_cls._init_ddl,
             compiled, 
             None,
             compiled
         )
+        if self._has_events:
+            self.engine.dispatch.after_execute(self, 
+                ddl, multiparams, params, ret)
+        return ret
 
     def _execute_clauseelement(self, elem, multiparams, params):
         """Execute a sql.ClauseElement object."""
 
-        params = self.__distill_params(multiparams, params)
-        if params:
-            keys = params[0].keys()
+        if self._has_events:
+            for fn in self.engine.dispatch.before_execute:
+                elem, multiparams, params = \
+                    fn(self, elem, multiparams, params)
+
+        distilled_params = self.__distill_params(multiparams, params)
+        if distilled_params:
+            keys = distilled_params[0].keys()
         else:
             keys = []
 
         dialect = self.dialect
         if 'compiled_cache' in self._execution_options:
-            key = dialect, elem, tuple(keys), len(params) > 1
+            key = dialect, elem, tuple(keys), len(distilled_params) > 1
             if key in self._execution_options['compiled_cache']:
                 compiled_sql = self._execution_options['compiled_cache'][key]
             else:
                 compiled_sql = elem.compile(
                                 dialect=dialect, column_keys=keys, 
-                                inline=len(params) > 1)
+                                inline=len(distilled_params) > 1)
                 self._execution_options['compiled_cache'][key] = compiled_sql
         else:
             compiled_sql = elem.compile(
                             dialect=dialect, column_keys=keys, 
-                            inline=len(params) > 1)
+                            inline=len(distilled_params) > 1)
 
 
-        return self._execute_context(
+        ret = self._execute_context(
             dialect,
             dialect.execution_ctx_cls._init_compiled,
             compiled_sql, 
-            params,
-            compiled_sql, params
+            distilled_params,
+            compiled_sql, distilled_params
         )
+        if self._has_events:
+            self.engine.dispatch.after_execute(self, 
+                elem, multiparams, params, ret)
+        return ret
 
     def _execute_compiled(self, compiled, multiparams, params):
         """Execute a sql.Compiled object."""
 
+        if self._has_events:
+            for fn in self.engine.dispatch.before_execute:
+                compiled, multiparams, params = \
+                    fn(self, compiled, multiparams, params)
+
         dialect = self.dialect
         parameters=self.__distill_params(multiparams, params)
-        return self._execute_context(
+        ret = self._execute_context(
             dialect,
             dialect.execution_ctx_cls._init_compiled,
             compiled, 
             parameters,
             compiled, parameters
         )
+        if self._has_events:
+            self.engine.dispatch.after_execute(self, 
+                compiled, multiparams, params, ret)
+        return ret
 
     def _execute_text(self, statement, multiparams, params):
         """Execute a string SQL statement."""
 
+        if self._has_events:
+            for fn in self.engine.dispatch.before_execute:
+                statement, multiparams, params = \
+                    fn(self, statement, multiparams, params)
+
         dialect = self.dialect
         parameters = self.__distill_params(multiparams, params)
-        return self._execute_context(
+        ret = self._execute_context(
             dialect,
             dialect.execution_ctx_cls._init_statement,
             statement, 
             parameters,
             statement, parameters
         )
-
-    _before_cursor_execute = None
-    _after_cursor_execute = None
+        if self._has_events:
+            self.engine.dispatch.after_execute(self, 
+                statement, multiparams, params, ret)
+        return ret
 
     def _execute_context(self, dialect, constructor, 
                                     statement, parameters, 
@@ -1395,12 +1469,11 @@ class Connection(Connectable):
         if not context.executemany:
             parameters = parameters[0]
 
-        if self._before_cursor_execute:
-            statement, parameters = self._before_cursor_execute(
-                                            context,
-                                            cursor, 
-                                            statement, 
-                                            parameters)
+        if self._has_events:
+            for fn in self.engine.dispatch.before_cursor_execute:
+                statement, parameters = \
+                            fn(self, cursor, statement, parameters, 
+                                        context, context.executemany)
 
         if self._echo:
             self.engine.logger.info(statement)
@@ -1428,9 +1501,12 @@ class Connection(Connectable):
             raise
 
 
-        if self._after_cursor_execute:
-            self._after_cursor_execute(context, cursor, 
-                                        statement, parameters)
+        if self._has_events:
+            self.engine.dispatch.after_cursor_execute(self, cursor, 
+                                                statement, 
+                                                parameters, 
+                                                context, 
+                                                context.executemany)
 
         if context.compiled:
             context.post_exec()
@@ -1757,6 +1833,7 @@ class Engine(Connectable, log.Identified):
     """
 
     _execution_options = util.immutabledict()
+    _has_events = False
     Connection = Connection
 
     def __init__(self, pool, dialect, url, 
@@ -1783,8 +1860,7 @@ class Engine(Connectable, log.Identified):
                 )
             self.update_execution_options(**execution_options)
 
-
-    dispatch = event.dispatcher(events.EngineEvents)
+    dispatch = event.dispatcher(events.ConnectionEvents)
 
     def update_execution_options(self, **opt):
         """update the execution_options dictionary of this :class:`Engine`.
@@ -2028,101 +2104,6 @@ class Engine(Connectable, log.Identified):
 
         return self.pool.unique_connection()
 
-def _listener_connection_cls(cls, dispatch):
-    """Produce a wrapper for :class:`.Connection` which will apply event 
-    dispatch to each method.
-
-    :class:`.Connection` does not provide event dispatch built in so that
-    method call overhead is avoided in the absense of any listeners.
-
-    """
-    class EventListenerConnection(cls):
-        def execute(self, clauseelement, *multiparams, **params):
-            for fn in dispatch.before_execute:
-                clauseelement, multiparams, params = \
-                    fn(self, clauseelement, multiparams, params)
-
-            ret = super(EventListenerConnection, self).\
-                    execute(clauseelement, *multiparams, **params)
-
-            for fn in dispatch.after_execute:
-                fn(self, clauseelement, multiparams, params, ret)
-
-            return ret
-
-        def _execute_clauseelement(self, clauseelement, 
-                                    multiparams=None, params=None):
-            return self.execute(clauseelement, 
-                                    *(multiparams or []), 
-                                    **(params or {}))
-
-        def _before_cursor_execute(self, context, cursor, 
-                                            statement, parameters):
-            for fn in dispatch.before_cursor_execute:
-                statement, parameters = \
-                            fn(self, cursor, statement, parameters, 
-                                        context, context.executemany)
-            return statement, parameters
-
-        def _after_cursor_execute(self, context, cursor, 
-                                            statement, parameters):
-            dispatch.after_cursor_execute(self, cursor, 
-                                                statement, 
-                                                parameters, 
-                                                context, 
-                                                context.executemany)
-
-        def _begin_impl(self):
-            dispatch.begin(self)
-            return super(EventListenerConnection, self).\
-                        _begin_impl()
-
-        def _rollback_impl(self):
-            dispatch.rollback(self)
-            return super(EventListenerConnection, self).\
-                        _rollback_impl()
-
-        def _commit_impl(self):
-            dispatch.commit(self)
-            return super(EventListenerConnection, self).\
-                        _commit_impl()
-
-        def _savepoint_impl(self, name=None):
-            dispatch.savepoint(self, name)
-            return super(EventListenerConnection, self).\
-                        _savepoint_impl(name=name)
-
-        def _rollback_to_savepoint_impl(self, name, context):
-            dispatch.rollback_savepoint(self, name, context)
-            return super(EventListenerConnection, self).\
-                        _rollback_to_savepoint_impl(name, context)
-
-        def _release_savepoint_impl(self, name, context):
-            dispatch.release_savepoint(self, name, context)
-            return super(EventListenerConnection, self).\
-                        _release_savepoint_impl(name, context)
-
-        def _begin_twophase_impl(self, xid):
-            dispatch.begin_twophase(self, xid)
-            return super(EventListenerConnection, self).\
-                        _begin_twophase_impl(xid)
-
-        def _prepare_twophase_impl(self, xid):
-            dispatch.prepare_twophase(self, xid)
-            return super(EventListenerConnection, self).\
-                        _prepare_twophase_impl(xid)
-
-        def _rollback_twophase_impl(self, xid, is_prepared):
-            dispatch.rollback_twophase(self, xid)
-            return super(EventListenerConnection, self).\
-                        _rollback_twophase_impl(xid, is_prepared)
-
-        def _commit_twophase_impl(self, xid, is_prepared):
-            dispatch.commit_twophase(self, xid, is_prepared)
-            return super(EventListenerConnection, self).\
-                        _commit_twophase_impl(xid, is_prepared)
-
-    return EventListenerConnection
 
 # This reconstructor is necessary so that pickles with the C extension or
 # without use the same Binary format.
index c2c21812bea34e658da631d3aa11fc347d21a424..eee19ee1d550e5b12c773bd801bf3e2737687a2e 100644 (file)
@@ -151,6 +151,13 @@ class DefaultEngineStrategy(EngineStrategy):
 
             def first_connect(dbapi_connection, connection_record):
                 c = base.Connection(engine, connection=dbapi_connection)
+
+                # TODO: removing this allows the on connect activities
+                # to generate events.  tests currently assume these aren't
+                # sent.  do we want users to get all the initial connect
+                # activities as events ?
+                c._has_events = False
+
                 dialect.initialize(c)
             event.listen(pool, 'first_connect', first_connect)
 
index d4d0488a986557a30b1d6c9a4a94069065bcad83..f3e84bef299b1808bfc8c914fb4f4eb6ffea2b56 100644 (file)
@@ -11,7 +11,7 @@ with :func:`~sqlalchemy.engine.create_engine`.  This module is semi-private and
 invoked automatically when the threadlocal engine strategy is used.
 """
 
-from sqlalchemy import util, event, events
+from sqlalchemy import util, event
 from sqlalchemy.engine import base
 import weakref
 
@@ -33,15 +33,6 @@ class TLConnection(base.Connection):
         self.__opencount = 0
         base.Connection.close(self)
 
-class TLEvents(events.EngineEvents):
-    @classmethod
-    def _listen(cls, target, identifier, fn):
-        if target.TLConnection is TLConnection:
-            target.TLConnection = base._listener_connection_cls(
-                                        TLConnection, 
-                                        target.dispatch)
-        events.EngineEvents._listen(target, identifier, fn)
-
 class TLEngine(base.Engine):
     """An Engine that includes support for thread-local managed transactions."""
 
@@ -51,7 +42,6 @@ class TLEngine(base.Engine):
         super(TLEngine, self).__init__(*args, **kwargs)
         self._connections = util.threading.local()
 
-    dispatch = event.dispatcher(TLEvents)
 
     def contextual_connect(self, **kw):
         if not hasattr(self._connections, 'conn'):
index 0fcc8ef499f8771174e18d2f1b00e511d11b1a19..7c2b49ce8a0001c9156cedf5f80d08984a27f233 100644 (file)
@@ -78,20 +78,17 @@ class _Dispatch(object):
         self._parent_cls = _parent_cls
 
     def __reduce__(self):
-
         return _UnpickleDispatch(), (self._parent_cls, )
 
-    @property
-    def _descriptors(self):
-        return (getattr(self, k) for k in dir(self) if _is_event_name(k))
-
     def _update(self, other, only_propagate=True):
         """Populate from the listeners in another :class:`_Dispatch`
             object."""
 
-        for ls in other._descriptors:
+        for ls in _event_descriptors(other):
             getattr(self, ls.name)._update(ls, only_propagate=only_propagate)
 
+def _event_descriptors(target):
+    return [getattr(target, k) for k in dir(target) if _is_event_name(k)]
 
 class _EventMeta(type):
     """Intercept new Event subclasses and create 
index fe9c5dda15a76cb7b889798d8aea25c14e2c016e..6435ff3f2ea1d14a15190a102d9becb137f66f96 100644 (file)
@@ -279,8 +279,8 @@ class PoolEvents(event.Events):
 
         """
 
-class EngineEvents(event.Events):
-    """Available events for :class:`.Engine`.
+class ConnectionEvents(event.Events):
+    """Available events for :class:`.Connection`.
 
     The methods here define the name of an event as well as the names of members that are passed to listener functions.
 
@@ -307,12 +307,7 @@ class EngineEvents(event.Events):
 
     @classmethod
     def _listen(cls, target, identifier, fn, retval=False):
-        from sqlalchemy.engine.base import Connection, \
-            _listener_connection_cls
-        if target.Connection is Connection:
-            target.Connection = _listener_connection_cls(
-                                        Connection, 
-                                        target.dispatch)
+        target._has_events = True
 
         if not retval:
             if identifier == 'before_execute':
index 01a0100abdeb37ebbed98b31f6e76b829f56e284..44a9316ccd621564fd3fe4326688713d449407b2 100644 (file)
@@ -1,8 +1,8 @@
-from test.lib.testing import eq_, assert_raises, assert_raises_message
+from test.lib.testing import eq_, assert_raises, assert_raises_message, config
 import re
 from sqlalchemy.interfaces import ConnectionProxy
 from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, \
-    bindparam, select, event, TypeDecorator
+    bindparam, select, event, TypeDecorator, create_engine
 from sqlalchemy.sql import column, literal
 from test.lib.schema import Table, Column
 import sqlalchemy as tsa
@@ -10,6 +10,7 @@ from test.lib import TestBase, testing, engines
 import logging
 from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
 from sqlalchemy.engine import base, default
+from sqlalchemy.engine.base import Connection, Engine
 
 users, metadata = None, None
 class ExecuteTest(TestBase):
@@ -514,6 +515,8 @@ class AlternateResultProxyTest(TestBase):
         self._test_proxy(base.BufferedColumnResultProxy)
 
 class EngineEventsTest(TestBase):
+    def tearDown(self):
+        Engine.dispatch._clear()
 
     def _assert_stmts(self, expected, received):
         for stmt, params, posn in expected:
@@ -528,6 +531,64 @@ class EngineEventsTest(TestBase):
                         == params or testparams == posn):
                     break
 
+    def test_per_engine_independence(self):
+        e1 = create_engine(config.db_url)
+        e2 = create_engine(config.db_url)
+
+        canary = []
+        def before_exec(conn, stmt, *arg):
+            canary.append(stmt)
+        event.listen(e1, "before_execute", before_exec)
+        s1 = select([1])
+        s2 = select([2])
+        e1.execute(s1)
+        e2.execute(s2)
+        eq_(canary, [s1])
+        event.listen(e2, "before_execute", before_exec)
+        e1.execute(s1)
+        e2.execute(s2)
+        eq_(canary, [s1, s1, s2])
+
+    def test_per_engine_plus_global(self):
+        canary = []
+        def be1(conn, stmt, *arg):
+            canary.append('be1')
+        def be2(conn, stmt, *arg):
+            canary.append('be2')
+        def be3(conn, stmt, *arg):
+            canary.append('be3')
+
+        event.listen(Engine, "before_execute", be1)
+        e1 = create_engine(config.db_url)
+        e2 = create_engine(config.db_url)
+
+        event.listen(e1, "before_execute", be2)
+
+        event.listen(Engine, "before_execute", be3)
+        e1.connect()
+        e2.connect()
+        canary[:] = []
+        e1.execute(select([1]))
+        e2.execute(select([1]))
+
+        eq_(canary, ['be1', 'be3', 'be2', 'be1', 'be3'])
+
+    def test_argument_format_execute(self):
+        def before_execute(conn, clauseelement, multiparams, params):
+            assert isinstance(multiparams, (list, tuple))
+            assert isinstance(params, dict)
+        def after_execute(conn, clauseelement, multiparams, params, result):
+            assert isinstance(multiparams, (list, tuple))
+            assert isinstance(params, dict)
+        e1 = create_engine(config.db_url)
+        event.listen(e1, 'before_execute', before_execute)
+        event.listen(e1, 'after_execute', after_execute)
+
+        e1.execute(select([1]))
+        e1.execute(select([1]).compile(dialect=e1.dialect).statement)
+        e1.execute(select([1]).compile(dialect=e1.dialect))
+        e1._execute_compiled(select([1]).compile(dialect=e1.dialect), [], {})
+
     @testing.fails_on('firebird', 'Data type unknown')
     def test_execute_events(self):
 
@@ -648,8 +709,6 @@ class EngineEventsTest(TestBase):
             canary, ['execute', 'cursor_execute']
         )
 
-
-
     def test_transactional(self):
         canary = []
         def tracker(name):