]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- pretty much all tests passing, maybe some callcounts are off
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jul 2010 17:08:39 +0000 (13:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Jul 2010 17:08:39 +0000 (13:08 -0400)
- test suite adjusted to use engine/pool events and not listeners
- deprecation warnings

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/event.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/test/assertsql.py
lib/sqlalchemy/test/engines.py
test/aaa_profiling/test_orm.py
test/aaa_profiling/test_pool.py
test/engine/test_execute.py
test/engine/test_pool.py
test/orm/test_merge.py

index a5f99022f93eecacfc34218c9476801037d79726..adba6fa47fa303526ffdaa94bed6a027718381af 100644 (file)
@@ -1573,18 +1573,19 @@ class Engine(Connectable, log.Identified):
         self.engine = self
         self.logger = log.instance_logger(self, echoflag=echo)
         if proxy:
+#            util.warn_deprecated("The 'proxy' argument to create_engine() is deprecated.  Use event.listen().")
             interfaces.ConnectionProxy._adapt_listener(self, proxy)
         if execution_options:
             self.update_execution_options(**execution_options)
 
     class events(event.Events):
         @classmethod
-        def listen(cls, target, fn, identifier):
+        def listen(cls, fn, identifier, target):
             if issubclass(target.Connection, Connection):
                 target.Connection = _proxy_connection_cls(
                                             Connection, 
                                             target.events)
-            event.Events.listen(target, fn, identifier)
+            event.Events.listen(fn, identifier, target)
             
         def on_execute(self, conn, execute, clauseelement, *multiparams, **params):
             """Intercept high level execute() events."""
index 20393a5b3f871c9feba453fc3d97f31b82b2b4b1..785c6e96ad58122f93eaecb21bc9c84dc2860677 100644 (file)
@@ -5,7 +5,7 @@ with :func:`~sqlalchemy.engine.create_engine`.  This module is semi-private and
 invoked automatically when the threadlocal engine strategy is used.
 """
 
-from sqlalchemy import util
+from sqlalchemy import util, event
 from sqlalchemy.engine import base
 import weakref
 
@@ -32,17 +32,21 @@ class TLEngine(base.Engine):
     """An Engine that includes support for thread-local managed transactions."""
 
     TLConnection = TLConnection
-    # TODO
-    #_dispatch = event.dispatcher(_TLEngineDispatch)
 
     def __init__(self, *args, **kwargs):
         super(TLEngine, self).__init__(*args, **kwargs)
         self._connections = util.threading.local()
-        
-        # dont have to deal with proxy here, the
-        # superclass constructor + class level 
-        # _dispatch handles it
-        
+
+    class events(base.Engine.events):
+        @classmethod
+        def listen(cls, fn, identifier, target):
+            if issubclass(target.TLConnection, TLConnection):
+                target.TLConnection = base._proxy_connection_cls(
+                                            TLConnection, 
+                                            target.events)
+            base.Engine.events.listen(fn, identifier, target)
+    events = event.dispatcher(events)
+    
     def contextual_connect(self, **kw):
         if not hasattr(self._connections, 'conn'):
             connection = None
index 375023e2830a8dc8d07c1ce60e2d7513161b37cc..18dd5348f43b72dee51745c92ac50ed7f6497cca 100644 (file)
@@ -13,7 +13,7 @@ from sqlalchemy import util
 def listen(fn, identifier, target, *args):
     """Listen for events, passing to fn."""
     
-    target.events.listen(target, fn, identifier)
+    target.events.listen(fn, identifier, target, *args)
 
 NO_RESULT = util.symbol('no_result')
 
@@ -31,16 +31,31 @@ class Events(object):
         self.parent_cls = parent_cls
     
     @classmethod
-    def listen(cls, target, fn, identifier):
+    def listen(cls, fn, identifier, target):
         getattr(target.events, identifier).append(fn, target)
+    
+    @property
+    def events(self):
+        """Iterate the Listeners objects."""
+        
+        return (getattr(self, k) for k in dir(self) if k.startswith("on_"))
         
+    def update(self, other):
+        """Populate from the listeners in another :class:`Events` object."""
+
+        for ls in other.events:
+            getattr(self, ls.name).extend(ls)
 
 class _ExecEvent(object):
-    def exec_and_clear(self, *args, **kw):
-        """Execute this event once, then clear all listeners."""
+    _exec_once = False
+    
+    def exec_once(self, *args, **kw):
+        """Execute this event, but only if it has not been
+        executed already for this collection."""
         
-        self(*args, **kw)
-        self[:] = []
+        if not self._exec_once:
+            self(*args, **kw)
+            self._exec_once = True
     
     def exec_until_return(self, *args, **kw):
         """Execute listeners for this event until
@@ -74,12 +89,13 @@ class EventDescriptor(object):
         self._clslevel = []
     
     def append(self, obj, target):
+        assert isinstance(target, type), "Class-level Event targets must be classes."
         self._clslevel.append((obj, target))
     
     def __get__(self, obj, cls):
         if obj is None:
             return self
-        obj.__dict__[self.__name__] = result = Listeners()
+        obj.__dict__[self.__name__] = result = Listeners(self.__name__)
         result.extend([
             fn for fn, target in 
             self._clslevel
@@ -91,6 +107,9 @@ class Listeners(_ExecEvent, list):
     """Represent a collection of listeners linked
     to an instance of :class:`Events`."""
     
+    def __init__(self, name):
+        self.name = name
+        
     def append(self, obj, target):
         list.append(self, obj)
 
index aa8d362f8ccb9f6c236217ddc36dfe1393aeed85..9574d28da690aa3724d7e7b9db5359c596ee5f54 100644 (file)
@@ -125,8 +125,11 @@ class Pool(log.Identified):
         self._reset_on_return = reset_on_return
         self.echo = echo
         if _dispatch:
-            self.events = _dispatch
+            self.events.update(_dispatch)
         if listeners:
+            util.warn_deprecated(
+                        "The 'listeners' argument to Pool (and "
+                        "create_engine()) is deprecated.  Use event.listen().")
             for l in listeners:
                 self.add_listener(l)
 
@@ -203,7 +206,7 @@ class Pool(log.Identified):
             """
     events = event.dispatcher(events)
         
-    @util.deprecated("Use event.listen()")
+    @util.deprecated("Pool.add_listener() is deprecated.  Use event.listen()")
     def add_listener(self, listener):
         """Add a ``PoolListener``-like object to this pool.
         
@@ -275,7 +278,7 @@ class _ConnectionRecord(object):
         self.connection = self.__connect()
         self.info = {}
 
-        pool.events.on_first_connect.exec_and_clear(self.connection, self)
+        pool.events.on_first_connect.exec_once(self.connection, self)
         pool.events.on_connect(self.connection, self)
 
     def close(self):
@@ -305,7 +308,7 @@ class _ConnectionRecord(object):
             self.connection = self.__connect()
             self.info.clear()
             if self.__pool.events.on_connect:
-                self.__pool.events.on_connect(self.connection, con_record)
+                self.__pool.events.on_connect(self.connection, self)
         elif self.__pool._recycle > -1 and \
                 time.time() - self.starttime > self.__pool._recycle:
             self.__pool.logger.info(
@@ -315,7 +318,7 @@ class _ConnectionRecord(object):
             self.connection = self.__connect()
             self.info.clear()
             if self.__pool.events.on_connect:
-                self.__pool.events.on_connect(self.connection, con_record)
+                self.__pool.events.on_connect(self.connection, self)
         return self.connection
 
     def __close(self):
index a044f9d02fbef8d610fbed41032e100eb9949e39..a389c81f845ee1265d9fcf7d64f0f21dfebc511b 100644 (file)
@@ -255,7 +255,7 @@ def _process_assertion_statement(query, context):
 
     return query
 
-class SQLAssert(ConnectionProxy):
+class SQLAssert(object):
     rules = None
     
     def add_rules(self, rules):
@@ -282,8 +282,8 @@ class SQLAssert(ConnectionProxy):
             
         return result
         
-    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
-        result = execute(cursor, statement, parameters, context)
+    def cursor_execute(self, conn, execute, cursor, statement, parameters, context, executemany):
+        result = execute(cursor, statement, parameters, context, executemany)
         
         if self.rules:
             rule = self.rules[0]
index 9e77f38d718d29a4fcbba737d31766b72a5ecba7..779f872646a8c289815422f2591963df20abc636 100644 (file)
@@ -2,6 +2,7 @@ import sys, types, weakref
 from collections import deque
 import config
 from sqlalchemy.util import function_named, callable
+from sqlalchemy import event
 import re
 import warnings
 
@@ -133,12 +134,10 @@ def testing_engine(url=None, options=None):
     url = url or config.db_url
     options = options or config.db_opts
 
-    options.setdefault('proxy', asserter)
-    
-    listeners = options.setdefault('listeners', [])
-    listeners.append(testing_reaper)
-
     engine = create_engine(url, **options)
+    event.listen(asserter.execute, 'on_execute', engine)
+    event.listen(asserter.cursor_execute, 'on_cursor_execute', engine)
+    event.listen(testing_reaper.checkout, 'on_checkout', engine.pool)
     
     # may want to call this, results
     # in first-connect initializers
index f2b876837c4967bffb833743b8e330fbd69c1419..4f94be79c9426a4ae8f325c6e483f768988fba22 100644 (file)
@@ -53,7 +53,7 @@ class MergeTest(_base.MappedTest):
         # down from 185 on this this is a small slice of a usually
         # bigger operation so using a small variance
 
-        @profiling.function_call_count(95, variance=0.001,
+        @profiling.function_call_count(93, variance=0.001,
                 versions={'2.4': 67, '3': 96})
         def go():
             return sess2.merge(p1, load=False)
index bc3c12d57f17a9d62d3dc5d131413efd8baf10dc..f99af50656ccd916c0055d1dba527aa0741e4371 100644 (file)
@@ -18,8 +18,8 @@ class QueuePoolTest(TestBase, AssertsExecutionResults):
                          use_threadlocal=True)
 
 
-    @profiling.function_call_count(64, {'2.4': 42, '2.7':59
-                                            '2.7+cextension':59,
+    @profiling.function_call_count(64, {'2.4': 42, '2.7':75
+                                            '2.7+cextension':75,
                                             '3.0':65, '3.1':65},
                                             variance=.10)
     def test_first_connect(self):
index 47879ece9e078bc710083e5e2ffefd9a9393059d..6e6069f04be8879bb031026addc5628eab6fdd95 100644 (file)
@@ -2,7 +2,7 @@ from sqlalchemy.test.testing import eq_, assert_raises
 import re
 from sqlalchemy.interfaces import ConnectionProxy
 from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, \
-    bindparam, select
+    bindparam, select, event
 from sqlalchemy.test.schema import Table, Column
 import sqlalchemy as tsa
 from sqlalchemy.test import TestBase, testing, engines
@@ -288,9 +288,189 @@ class ResultProxyTest(TestBase):
             assert_raises(AssertionError, t.delete().execute)
         finally:
             engine.dialect.execution_ctx_cls = execution_ctx_cls
+
+class EngineEventsTest(TestBase):
+
+    @testing.fails_on('firebird', 'Data type unknown')
+    def test_execute_events(self):
+
+        stmts = []
+        cursor_stmts = []
+
+        def execute(conn, execute, clauseelement, *multiparams,
+                                                    **params ):
+            stmts.append((str(clauseelement), params, multiparams))
+            return execute(clauseelement, *multiparams, **params)
+
+        def cursor_execute(conn, execute, cursor, statement, parameters, 
+                                context, executemany):
+            cursor_stmts.append((str(statement), parameters, None))
+            return execute(cursor, statement, parameters, context, executemany)
+
+        def assert_stmts(expected, received):
+            for stmt, params, posn in expected:
+                if not received:
+                    assert False
+                while received:
+                    teststmt, testparams, testmultiparams = \
+                        received.pop(0)
+                    teststmt = re.compile(r'[\n\t ]+', re.M).sub(' ',
+                            teststmt).strip()
+                    if teststmt.startswith(stmt) and (testparams
+                            == params or testparams == posn):
+                        break
+
+        for engine in \
+            engines.testing_engine(options=dict(implicit_returning=False)), \
+            engines.testing_engine(options=dict(implicit_returning=False,
+                                   strategy='threadlocal')):
+            event.listen(execute, 'on_execute', engine)
+            event.listen(cursor_execute, 'on_cursor_execute', engine)
+            
+            m = MetaData(engine)
+            t1 = Table('t1', m, 
+                Column('c1', Integer, primary_key=True), 
+                Column('c2', String(50), default=func.lower('Foo'),
+                                            primary_key=True)
+            )
+            m.create_all()
+            try:
+                t1.insert().execute(c1=5, c2='some data')
+                t1.insert().execute(c1=6)
+                eq_(engine.execute('select * from t1').fetchall(), [(5,
+                    'some data'), (6, 'foo')])
+            finally:
+                m.drop_all()
+            engine.dispose()
+            compiled = [('CREATE TABLE t1', {}, None),
+                        ('INSERT INTO t1 (c1, c2)', {'c2': 'some data',
+                        'c1': 5}, None), ('INSERT INTO t1 (c1, c2)',
+                        {'c1': 6}, None), ('select * from t1', {},
+                        None), ('DROP TABLE t1', {}, None)]
+            if not testing.against('oracle+zxjdbc'):  # or engine.dialect.pr
+                                                      # eexecute_pk_sequence
+                                                      # s:
+                cursor = [
+                    ('CREATE TABLE t1', {}, ()),
+                    ('INSERT INTO t1 (c1, c2)', {'c2': 'some data', 'c1'
+                     : 5}, (5, 'some data')),
+                    ('SELECT lower', {'lower_2': 'Foo'}, ('Foo', )),
+                    ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6},
+                     (6, 'foo')),
+                    ('select * from t1', {}, ()),
+                    ('DROP TABLE t1', {}, ()),
+                    ]
+            else:
+                insert2_params = 6, 'Foo'
+                if testing.against('oracle+zxjdbc'):
+                    insert2_params += (ReturningParam(12), )
+                cursor = [('CREATE TABLE t1', {}, ()),
+                          ('INSERT INTO t1 (c1, c2)', {'c2': 'some data'
+                          , 'c1': 5}, (5, 'some data')),
+                          ('INSERT INTO t1 (c1, c2)', {'c1': 6,
+                          'lower_2': 'Foo'}, insert2_params),
+                          ('select * from t1', {}, ()), ('DROP TABLE t1'
+                          , {}, ())]  # bind param name 'lower_2' might
+                                      # be incorrect
+            assert_stmts(compiled, stmts)
+            assert_stmts(cursor, cursor_stmts)
+
+    def test_options(self):
+        track = []
+        def on_execute(conn, exec_, *args, **kw):
+            track.append('execute')
+            return exec_(*args, **kw)
+            
+        def on_cursor_execute(conn, exec_, *args, **kw):
+            track.append('cursor_execute')
+            return exec_(*args, **kw)
+            
+        engine = engines.testing_engine()
+        event.listen(on_execute, 'on_execute', engine)
+        event.listen(on_cursor_execute, 'on_cursor_execute', engine)
+        conn = engine.connect()
+        c2 = conn.execution_options(foo='bar')
+        eq_(c2._execution_options, {'foo':'bar'})
+        c2.execute(select([1]))
+        c3 = c2.execution_options(bar='bat')
+        eq_(c3._execution_options, {'foo':'bar', 'bar':'bat'})
+        eq_(track, ['execute', 'cursor_execute'])
+
+
+    def test_transactional(self):
+        track = []
+        def tracker(name):
+            def go(conn, exec_, *args, **kw):
+                track.append(name)
+                return exec_(*args, **kw)
+            return go
+            
+        engine = engines.testing_engine()
+        event.listen(tracker('execute'), 'on_execute', engine)
+        event.listen(tracker('cursor_execute'), 'on_cursor_execute', engine)
+        event.listen(tracker('begin'), 'on_begin', engine)
+        event.listen(tracker('commit'), 'on_commit', engine)
+        event.listen(tracker('rollback'), 'on_rollback', engine)
         
-class ProxyConnectionTest(TestBase):
+        conn = engine.connect()
+        trans = conn.begin()
+        conn.execute(select([1]))
+        trans.rollback()
+        trans = conn.begin()
+        conn.execute(select([1]))
+        trans.commit()
+
+        eq_(track, [
+            'begin', 'execute', 'cursor_execute', 'rollback',
+            'begin', 'execute', 'cursor_execute', 'commit',
+            ])
 
+    @testing.requires.savepoints
+    @testing.requires.two_phase_transactions
+    def test_transactional_advanced(self):
+        track = []
+        def tracker(name):
+            def go(conn, exec_, *args, **kw):
+                track.append(name)
+                return exec_(*args, **kw)
+            return go
+            
+        engine = engines.testing_engine()
+        for name in ['begin', 'savepoint', 
+                    'rollback_savepoint', 'release_savepoint',
+                    'rollback', 'begin_twophase', 
+                       'prepare_twophase', 'commit_twophase']:
+            event.listen(tracker(name), 'on_%s' % name, engine)
+
+        conn = engine.connect()
+
+        trans = conn.begin()
+        trans2 = conn.begin_nested()
+        conn.execute(select([1]))
+        trans2.rollback()
+        trans2 = conn.begin_nested()
+        conn.execute(select([1]))
+        trans2.commit()
+        trans.rollback()
+
+        trans = conn.begin_twophase()
+        conn.execute(select([1]))
+        trans.prepare()
+        trans.commit()
+
+        eq_(track, ['begin', 'savepoint', 
+                    'rollback_savepoint', 'savepoint', 'release_savepoint',
+                    'rollback', 'begin_twophase', 
+                       'prepare_twophase', 'commit_twophase']
+        )
+        
+class ProxyConnectionTest(TestBase):
+    """These are the same tests as EngineEventsTest, except using
+    the deprecated ConnectionProxy interface.
+    
+    """
+    
+    @testing.uses_deprecated(r'.*Use event.listen')
     @testing.fails_on('firebird', 'Data type unknown')
     def test_proxy(self):
         
@@ -388,6 +568,7 @@ class ProxyConnectionTest(TestBase):
             assert_stmts(compiled, stmts)
             assert_stmts(cursor, cursor_stmts)
     
+    @testing.uses_deprecated(r'.*Use event.listen')
     def test_options(self):
         track = []
         class TrackProxy(ConnectionProxy):
@@ -407,6 +588,7 @@ class ProxyConnectionTest(TestBase):
         eq_(track, ['execute', 'cursor_execute'])
         
         
+    @testing.uses_deprecated(r'.*Use event.listen')
     def test_transactional(self):
         track = []
         class TrackProxy(ConnectionProxy):
@@ -427,16 +609,11 @@ class ProxyConnectionTest(TestBase):
         trans.commit()
         
         eq_(track, [
-            'begin',
-            'execute',
-            'cursor_execute',
-            'rollback',
-            'begin',
-            'execute',
-            'cursor_execute',
-            'commit',
+            'begin', 'execute', 'cursor_execute', 'rollback',
+            'begin', 'execute', 'cursor_execute', 'commit',
             ])
         
+    @testing.uses_deprecated(r'.*Use event.listen')
     @testing.requires.savepoints
     @testing.requires.two_phase_transactions
     def test_transactional_advanced(self):
index 9db65d2ab83874d09f873ec17279a1621c529daa..c9cd6bdd44a4bbe48e7a4e23a80995d6257f3a46 100644 (file)
@@ -1,5 +1,5 @@
 import threading, time
-from sqlalchemy import pool, interfaces, create_engine, select
+from sqlalchemy import pool, interfaces, create_engine, select, event
 import sqlalchemy as tsa
 from sqlalchemy.test import TestBase, testing
 from sqlalchemy.test.util import gc_collect, lazy_gc
@@ -186,7 +186,8 @@ class PoolTest(PoolTestBase):
         self.assert_(c.connection is not c2.connection)
         self.assert_(not c2.info)
         self.assert_('foo2' in c.info)
-
+    
+    @testing.uses_deprecated(r".*Use event.listen")
     def test_listeners(self):
         dbapi = MockDBAPI()
 
@@ -260,11 +261,10 @@ class PoolTest(PoolTestBase):
 
         def assert_listeners(p, total, conn, fconn, cout, cin):
             for instance in (p, p.recreate()):
-                self.assert_(len(instance.listeners) == total)
-                self.assert_(len(instance._on_connect) == conn)
-                self.assert_(len(instance._on_first_connect) == fconn)
-                self.assert_(len(instance._on_checkout) == cout)
-                self.assert_(len(instance._on_checkin) == cin)
+                self.assert_(len(instance.events.on_connect) == conn)
+                self.assert_(len(instance.events.on_first_connect) == fconn)
+                self.assert_(len(instance.events.on_checkout) == cout)
+                self.assert_(len(instance.events.on_checkin) == cin)
 
         p = _pool()
         assert_listeners(p, 0, 0, 0, 0, 0)
@@ -368,6 +368,7 @@ class PoolTest(PoolTestBase):
         c.close()
         snoop.assert_total(1, 1, 2, 2)
     
+    @testing.uses_deprecated(r".*Use event.listen")
     def test_listeners_callables(self):
         dbapi = MockDBAPI()
 
@@ -391,10 +392,9 @@ class PoolTest(PoolTestBase):
 
             def assert_listeners(p, total, conn, cout, cin):
                 for instance in (p, p.recreate()):
-                    self.assert_(len(instance.listeners) == total)
-                    self.assert_(len(instance._on_connect) == conn)
-                    self.assert_(len(instance._on_checkout) == cout)
-                    self.assert_(len(instance._on_checkin) == cin)
+                    self.assert_(len(instance.events.on_connect) == conn)
+                    self.assert_(len(instance.events.on_checkout) == cout)
+                    self.assert_(len(instance.events.on_checkin) == cin)
 
             p = _pool()
             assert_listeners(p, 0, 0, 0, 0)
@@ -431,9 +431,8 @@ class PoolTest(PoolTestBase):
         called = []
         def listener(*args):
             called.append(True)
-        listener.connect = listener
         engine = create_engine(testing.db.url)
-        engine.pool.add_listener(listener)
+        event.listen(listener, 'on_connect', engine.pool)
         engine.execute(select([1])).close()
         assert called, "Listener not called on connect"
 
index d63d7e086ec0b382cf4f5189944e8d73280f45fb..ad103a8267c1685d9061d39fee16f95336137cad 100644 (file)
@@ -24,7 +24,7 @@ class MergeTest(_fixtures.FixtureTest):
             canary.called = 0
 
         manager = sa.orm.attributes.manager_of_class(cls)
-        manager.events.add_listener('on_load', canary)
+        manager.events.listen(canary, 'on_load', manager)
 
         return canary