self.engine = self
self.logger = log.instance_logger(self, echoflag=echo)
if proxy:
+# util.warn_deprecated("The 'proxy' argument to create_engine() is deprecated. Use event.listen().")
interfaces.ConnectionProxy._adapt_listener(self, proxy)
if execution_options:
self.update_execution_options(**execution_options)
class events(event.Events):
@classmethod
- def listen(cls, target, fn, identifier):
+ def listen(cls, fn, identifier, target):
if issubclass(target.Connection, Connection):
target.Connection = _proxy_connection_cls(
Connection,
target.events)
- event.Events.listen(target, fn, identifier)
+ event.Events.listen(fn, identifier, target)
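# NOTE (annotation, not part of the patch): registering any engine-level
# listener appears to replace target.Connection with a _proxy_connection_cls()
# subclass bound to target.events, so that execute() / cursor_execute() calls
# are routed through the listener collections; an Engine with no listeners
# presumably keeps the plain Connection class and skips that dispatch.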
def on_execute(self, conn, execute, clauseelement, *multiparams, **params):
"""Intercept high level execute() events."""
invoked automatically when the threadlocal engine strategy is used.
"""
-from sqlalchemy import util
+from sqlalchemy import util, event
from sqlalchemy.engine import base
import weakref
"""An Engine that includes support for thread-local managed transactions."""
TLConnection = TLConnection
- # TODO
- #_dispatch = event.dispatcher(_TLEngineDispatch)
def __init__(self, *args, **kwargs):
super(TLEngine, self).__init__(*args, **kwargs)
self._connections = util.threading.local()
-
- # dont have to deal with proxy here, the
- # superclass constructor + class level
- # _dispatch handles it
-
+
+ class events(base.Engine.events):
+ @classmethod
+ def listen(cls, fn, identifier, target):
+ if issubclass(target.TLConnection, TLConnection):
+ target.TLConnection = base._proxy_connection_cls(
+ TLConnection,
+ target.events)
+ base.Engine.events.listen(fn, identifier, target)
+ events = event.dispatcher(events)
+
def contextual_connect(self, **kw):
if not hasattr(self._connections, 'conn'):
connection = None
def listen(fn, identifier, target, *args):
"""Listen for events, passing to fn."""
- target.events.listen(target, fn, identifier)
+ target.events.listen(fn, identifier, target, *args)
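# Illustrative usage sketch (not part of the patch; the in-memory SQLite URL
# and the listener below are assumptions). A listener is registered by passing
# the callable, the event name, and the target, as the tests further down do:
from sqlalchemy import create_engine, event

statements = []

def log_execute(conn, execute, clauseelement, *multiparams, **params):
    # record the statement, then invoke and return the underlying execute()
    statements.append(str(clauseelement))
    return execute(clauseelement, *multiparams, **params)

engine = create_engine('sqlite://')
event.listen(log_execute, 'on_execute', engine)
engine.execute('select 1')
# statements should now contain ['select 1']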
NO_RESULT = util.symbol('no_result')
self.parent_cls = parent_cls
@classmethod
- def listen(cls, target, fn, identifier):
+ def listen(cls, fn, identifier, target):
getattr(target.events, identifier).append(fn, target)
+
+ @property
+ def events(self):
+ """Iterate the Listeners objects."""
+
+ return (getattr(self, k) for k in dir(self) if k.startswith("on_"))
+ def update(self, other):
+ """Populate from the listeners in another :class:`Events` object."""
+
+ for ls in other.events:
+ getattr(self, ls.name).extend(ls)
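# NOTE (annotation, not part of the patch): update() is what lets an existing
# dispatch carry over -- e.g. Pool.__init__ below calls
# self.events.update(_dispatch), so Pool.recreate() appears to preserve
# listeners registered on the original pool, as the test_listeners assertions
# against p.recreate() exercise.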
class _ExecEvent(object):
- def exec_and_clear(self, *args, **kw):
- """Execute this event once, then clear all listeners."""
+ _exec_once = False
+
+ def exec_once(self, *args, **kw):
+ """Execute this event, but only if it has not been
+ executed already for this collection."""
- self(*args, **kw)
- self[:] = []
+ if not self._exec_once:
+ self(*args, **kw)
+ self._exec_once = True
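# Minimal standalone sketch of the exec_once() behavior introduced above (not
# part of the patch; _OnceList is a stand-in, not the library's Listeners
# class). Unlike the removed exec_and_clear(), the listener list is left
# intact; only re-execution is suppressed:
class _OnceList(list):
    _exec_once = False

    def __call__(self, *args, **kw):
        for fn in self:
            fn(*args, **kw)

    def exec_once(self, *args, **kw):
        if not self._exec_once:
            self(*args, **kw)
            self._exec_once = True

seen = []
coll = _OnceList([seen.append])
coll.exec_once('first')    # listeners invoked
coll.exec_once('second')   # skipped; already executed for this collection
assert seen == ['first']
assert len(coll) == 1      # listeners were not cleared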
def exec_until_return(self, *args, **kw):
"""Execute listeners for this event until
self._clslevel = []
def append(self, obj, target):
+ assert isinstance(target, type), "Class-level Event targets must be classes."
self._clslevel.append((obj, target))
def __get__(self, obj, cls):
if obj is None:
return self
- obj.__dict__[self.__name__] = result = Listeners()
+ obj.__dict__[self.__name__] = result = Listeners(self.__name__)
result.extend([
fn for fn, target in
self._clslevel
"""Represent a collection of listeners linked
to an instance of :class:`Events`."""
+ def __init__(self, name):
+ self.name = name
+
def append(self, obj, target):
list.append(self, obj)
self._reset_on_return = reset_on_return
self.echo = echo
if _dispatch:
- self.events = _dispatch
+ self.events.update(_dispatch)
if listeners:
+ util.warn_deprecated(
+ "The 'listeners' argument to Pool (and "
+ "create_engine()) is deprecated. Use event.listen().")
for l in listeners:
self.add_listener(l)
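# Illustrative sketch of the replacement registration style (not part of the
# patch; the SQLite URL and listener are assumptions). Instead of the
# deprecated 'listeners' argument / PoolListener objects, a plain function is
# attached to the pool's event, matching the argument signature used by the
# on_connect calls above (DBAPI connection, connection record):
from sqlalchemy import create_engine, event

def my_on_connect(dbapi_con, con_record):
    # invoked for each newly created DBAPI connection
    pass

engine = create_engine('sqlite://')
event.listen(my_on_connect, 'on_connect', engine.pool)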
"""
events = event.dispatcher(events)
- @util.deprecated("Use event.listen()")
+ @util.deprecated("Pool.add_listener() is deprecated. Use event.listen()")
def add_listener(self, listener):
"""Add a ``PoolListener``-like object to this pool.
self.connection = self.__connect()
self.info = {}
- pool.events.on_first_connect.exec_and_clear(self.connection, self)
+ pool.events.on_first_connect.exec_once(self.connection, self)
pool.events.on_connect(self.connection, self)
def close(self):
self.connection = self.__connect()
self.info.clear()
if self.__pool.events.on_connect:
- self.__pool.events.on_connect(self.connection, con_record)
+ self.__pool.events.on_connect(self.connection, self)
elif self.__pool._recycle > -1 and \
time.time() - self.starttime > self.__pool._recycle:
self.__pool.logger.info(
self.connection = self.__connect()
self.info.clear()
if self.__pool.events.on_connect:
- self.__pool.events.on_connect(self.connection, con_record)
+ self.__pool.events.on_connect(self.connection, self)
return self.connection
def __close(self):
return query
-class SQLAssert(ConnectionProxy):
+class SQLAssert(object):
rules = None
def add_rules(self, rules):
return result
- def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
- result = execute(cursor, statement, parameters, context)
+ def cursor_execute(self, conn, execute, cursor, statement, parameters, context, executemany):
+ result = execute(cursor, statement, parameters, context, executemany)
if self.rules:
rule = self.rules[0]
from collections import deque
import config
from sqlalchemy.util import function_named, callable
+from sqlalchemy import event
import re
import warnings
url = url or config.db_url
options = options or config.db_opts
- options.setdefault('proxy', asserter)
-
- listeners = options.setdefault('listeners', [])
- listeners.append(testing_reaper)
-
engine = create_engine(url, **options)
+ event.listen(asserter.execute, 'on_execute', engine)
+ event.listen(asserter.cursor_execute, 'on_cursor_execute', engine)
+ event.listen(testing_reaper.checkout, 'on_checkout', engine.pool)
# may want to call this, results
# in first-connect initializers
# down from 185 on this. This is a small slice of a usually
# bigger operation, so using a small variance
- @profiling.function_call_count(95, variance=0.001,
+ @profiling.function_call_count(93, variance=0.001,
versions={'2.4': 67, '3': 96})
def go():
return sess2.merge(p1, load=False)
use_threadlocal=True)
- @profiling.function_call_count(64, {'2.4': 42, '2.7':59,
- '2.7+cextension':59,
+ @profiling.function_call_count(64, {'2.4': 42, '2.7':75,
+ '2.7+cextension':75,
'3.0':65, '3.1':65},
variance=.10)
def test_first_connect(self):
import re
from sqlalchemy.interfaces import ConnectionProxy
from sqlalchemy import MetaData, Integer, String, INT, VARCHAR, func, \
- bindparam, select
+ bindparam, select, event
from sqlalchemy.test.schema import Table, Column
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, testing, engines
assert_raises(AssertionError, t.delete().execute)
finally:
engine.dialect.execution_ctx_cls = execution_ctx_cls
+
+class EngineEventsTest(TestBase):
+
+ @testing.fails_on('firebird', 'Data type unknown')
+ def test_execute_events(self):
+
+ stmts = []
+ cursor_stmts = []
+
+ def execute(conn, execute, clauseelement, *multiparams,
+ **params ):
+ stmts.append((str(clauseelement), params, multiparams))
+ return execute(clauseelement, *multiparams, **params)
+
+ def cursor_execute(conn, execute, cursor, statement, parameters,
+ context, executemany):
+ cursor_stmts.append((str(statement), parameters, None))
+ return execute(cursor, statement, parameters, context, executemany)
+
+ def assert_stmts(expected, received):
+ for stmt, params, posn in expected:
+ if not received:
+ assert False
+ while received:
+ teststmt, testparams, testmultiparams = \
+ received.pop(0)
+ teststmt = re.compile(r'[\n\t ]+', re.M).sub(' ',
+ teststmt).strip()
+ if teststmt.startswith(stmt) and (testparams
+ == params or testparams == posn):
+ break
+
+ for engine in \
+ engines.testing_engine(options=dict(implicit_returning=False)), \
+ engines.testing_engine(options=dict(implicit_returning=False,
+ strategy='threadlocal')):
+ event.listen(execute, 'on_execute', engine)
+ event.listen(cursor_execute, 'on_cursor_execute', engine)
+
+ m = MetaData(engine)
+ t1 = Table('t1', m,
+ Column('c1', Integer, primary_key=True),
+ Column('c2', String(50), default=func.lower('Foo'),
+ primary_key=True)
+ )
+ m.create_all()
+ try:
+ t1.insert().execute(c1=5, c2='some data')
+ t1.insert().execute(c1=6)
+ eq_(engine.execute('select * from t1').fetchall(), [(5,
+ 'some data'), (6, 'foo')])
+ finally:
+ m.drop_all()
+ engine.dispose()
+ compiled = [
+ ('CREATE TABLE t1', {}, None),
+ ('INSERT INTO t1 (c1, c2)', {'c2': 'some data', 'c1': 5}, None),
+ ('INSERT INTO t1 (c1, c2)', {'c1': 6}, None),
+ ('select * from t1', {}, None),
+ ('DROP TABLE t1', {}, None),
+ ]
+ if not testing.against('oracle+zxjdbc'): # or engine.dialect.preexecute_pk_sequences:
+ cursor = [
+ ('CREATE TABLE t1', {}, ()),
+ ('INSERT INTO t1 (c1, c2)', {'c2': 'some data', 'c1': 5}, (5, 'some data')),
+ ('SELECT lower', {'lower_2': 'Foo'}, ('Foo', )),
+ ('INSERT INTO t1 (c1, c2)', {'c2': 'foo', 'c1': 6}, (6, 'foo')),
+ ('select * from t1', {}, ()),
+ ('DROP TABLE t1', {}, ()),
+ ]
+ else:
+ insert2_params = 6, 'Foo'
+ if testing.against('oracle+zxjdbc'):
+ insert2_params += (ReturningParam(12), )
+ # bind param name 'lower_2' might be incorrect
+ cursor = [
+ ('CREATE TABLE t1', {}, ()),
+ ('INSERT INTO t1 (c1, c2)', {'c2': 'some data', 'c1': 5}, (5, 'some data')),
+ ('INSERT INTO t1 (c1, c2)', {'c1': 6, 'lower_2': 'Foo'}, insert2_params),
+ ('select * from t1', {}, ()),
+ ('DROP TABLE t1', {}, ()),
+ ]
+ assert_stmts(compiled, stmts)
+ assert_stmts(cursor, cursor_stmts)
+
+ def test_options(self):
+ track = []
+ def on_execute(conn, exec_, *args, **kw):
+ track.append('execute')
+ return exec_(*args, **kw)
+
+ def on_cursor_execute(conn, exec_, *args, **kw):
+ track.append('cursor_execute')
+ return exec_(*args, **kw)
+
+ engine = engines.testing_engine()
+ event.listen(on_execute, 'on_execute', engine)
+ event.listen(on_cursor_execute, 'on_cursor_execute', engine)
+ conn = engine.connect()
+ c2 = conn.execution_options(foo='bar')
+ eq_(c2._execution_options, {'foo':'bar'})
+ c2.execute(select([1]))
+ c3 = c2.execution_options(bar='bat')
+ eq_(c3._execution_options, {'foo':'bar', 'bar':'bat'})
+ eq_(track, ['execute', 'cursor_execute'])
+
+
+ def test_transactional(self):
+ track = []
+ def tracker(name):
+ def go(conn, exec_, *args, **kw):
+ track.append(name)
+ return exec_(*args, **kw)
+ return go
+
+ engine = engines.testing_engine()
+ event.listen(tracker('execute'), 'on_execute', engine)
+ event.listen(tracker('cursor_execute'), 'on_cursor_execute', engine)
+ event.listen(tracker('begin'), 'on_begin', engine)
+ event.listen(tracker('commit'), 'on_commit', engine)
+ event.listen(tracker('rollback'), 'on_rollback', engine)
+ conn = engine.connect()
+ trans = conn.begin()
+ conn.execute(select([1]))
+ trans.rollback()
+ trans = conn.begin()
+ conn.execute(select([1]))
+ trans.commit()
+
+ eq_(track, [
+ 'begin', 'execute', 'cursor_execute', 'rollback',
+ 'begin', 'execute', 'cursor_execute', 'commit',
+ ])
+ @testing.requires.savepoints
+ @testing.requires.two_phase_transactions
+ def test_transactional_advanced(self):
+ track = []
+ def tracker(name):
+ def go(conn, exec_, *args, **kw):
+ track.append(name)
+ return exec_(*args, **kw)
+ return go
+
+ engine = engines.testing_engine()
+ for name in ['begin', 'savepoint',
+ 'rollback_savepoint', 'release_savepoint',
+ 'rollback', 'begin_twophase',
+ 'prepare_twophase', 'commit_twophase']:
+ event.listen(tracker(name), 'on_%s' % name, engine)
+
+ conn = engine.connect()
+
+ trans = conn.begin()
+ trans2 = conn.begin_nested()
+ conn.execute(select([1]))
+ trans2.rollback()
+ trans2 = conn.begin_nested()
+ conn.execute(select([1]))
+ trans2.commit()
+ trans.rollback()
+
+ trans = conn.begin_twophase()
+ conn.execute(select([1]))
+ trans.prepare()
+ trans.commit()
+
+ eq_(track, ['begin', 'savepoint',
+ 'rollback_savepoint', 'savepoint', 'release_savepoint',
+ 'rollback', 'begin_twophase',
+ 'prepare_twophase', 'commit_twophase']
+ )
+
-class ProxyConnectionTest(TestBase):
+class ProxyConnectionTest(TestBase):
+ """These are the same tests as EngineEventsTest, except using
+ the deprecated ConnectionProxy interface.
+
+ """
+
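# Illustrative sketch of the deprecated interface exercised below (not part of
# the patch; MyProxy and the URL are hypothetical). A ConnectionProxy subclass
# was passed via the 'proxy' argument to create_engine(), which now warns:
from sqlalchemy import create_engine
from sqlalchemy.interfaces import ConnectionProxy

class MyProxy(ConnectionProxy):
    def execute(self, conn, execute, clauseelement, *multiparams, **params):
        # same interception point as the 'on_execute' event above
        return execute(clauseelement, *multiparams, **params)

engine = create_engine('sqlite://', proxy=MyProxy())
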
+ @testing.uses_deprecated(r'.*Use event.listen')
@testing.fails_on('firebird', 'Data type unknown')
def test_proxy(self):
assert_stmts(compiled, stmts)
assert_stmts(cursor, cursor_stmts)
+ @testing.uses_deprecated(r'.*Use event.listen')
def test_options(self):
track = []
class TrackProxy(ConnectionProxy):
eq_(track, ['execute', 'cursor_execute'])
+ @testing.uses_deprecated(r'.*Use event.listen')
def test_transactional(self):
track = []
class TrackProxy(ConnectionProxy):
trans.commit()
eq_(track, [
- 'begin',
- 'execute',
- 'cursor_execute',
- 'rollback',
- 'begin',
- 'execute',
- 'cursor_execute',
- 'commit',
+ 'begin', 'execute', 'cursor_execute', 'rollback',
+ 'begin', 'execute', 'cursor_execute', 'commit',
])
+ @testing.uses_deprecated(r'.*Use event.listen')
@testing.requires.savepoints
@testing.requires.two_phase_transactions
def test_transactional_advanced(self):
import threading, time
-from sqlalchemy import pool, interfaces, create_engine, select
+from sqlalchemy import pool, interfaces, create_engine, select, event
import sqlalchemy as tsa
from sqlalchemy.test import TestBase, testing
from sqlalchemy.test.util import gc_collect, lazy_gc
self.assert_(c.connection is not c2.connection)
self.assert_(not c2.info)
self.assert_('foo2' in c.info)
-
+
+ @testing.uses_deprecated(r".*Use event.listen")
def test_listeners(self):
dbapi = MockDBAPI()
def assert_listeners(p, total, conn, fconn, cout, cin):
for instance in (p, p.recreate()):
- self.assert_(len(instance.listeners) == total)
- self.assert_(len(instance._on_connect) == conn)
- self.assert_(len(instance._on_first_connect) == fconn)
- self.assert_(len(instance._on_checkout) == cout)
- self.assert_(len(instance._on_checkin) == cin)
+ self.assert_(len(instance.events.on_connect) == conn)
+ self.assert_(len(instance.events.on_first_connect) == fconn)
+ self.assert_(len(instance.events.on_checkout) == cout)
+ self.assert_(len(instance.events.on_checkin) == cin)
p = _pool()
assert_listeners(p, 0, 0, 0, 0, 0)
c.close()
snoop.assert_total(1, 1, 2, 2)
+ @testing.uses_deprecated(r".*Use event.listen")
def test_listeners_callables(self):
dbapi = MockDBAPI()
def assert_listeners(p, total, conn, cout, cin):
for instance in (p, p.recreate()):
- self.assert_(len(instance.listeners) == total)
- self.assert_(len(instance._on_connect) == conn)
- self.assert_(len(instance._on_checkout) == cout)
- self.assert_(len(instance._on_checkin) == cin)
+ self.assert_(len(instance.events.on_connect) == conn)
+ self.assert_(len(instance.events.on_checkout) == cout)
+ self.assert_(len(instance.events.on_checkin) == cin)
p = _pool()
assert_listeners(p, 0, 0, 0, 0)
called = []
def listener(*args):
called.append(True)
- listener.connect = listener
engine = create_engine(testing.db.url)
- engine.pool.add_listener(listener)
+ event.listen(listener, 'on_connect', engine.pool)
engine.execute(select([1])).close()
assert called, "Listener not called on connect"
canary.called = 0
manager = sa.orm.attributes.manager_of_class(cls)
- manager.events.add_listener('on_load', canary)
+ manager.events.listen(canary, 'on_load', manager)
return canary