From: Mike Bayer Date: Sun, 14 Nov 2010 22:51:54 +0000 (-0500) Subject: - SessionEvents is on board and the event model is done, can start building 0.7 tip... X-Git-Tag: rel_0_7b1~253^2~3 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=4ac324067961f0d4452994083f5aa1a71f6d6d71;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - SessionEvents is on board and the event model is done, can start building 0.7 tip soon --- diff --git a/doc/build/orm/events.rst b/doc/build/orm/events.rst index fdbe3d3393..45c947ae0e 100644 --- a/doc/build/orm/events.rst +++ b/doc/build/orm/events.rst @@ -30,7 +30,8 @@ Instance Events Session Events -------------- -TODO +.. autoclass:: sqlalchemy.orm.events.SessionEvents + :members: Instrumentation Events ----------------------- diff --git a/lib/sqlalchemy/event.py b/lib/sqlalchemy/event.py index c39aff63e1..359a4c017d 100644 --- a/lib/sqlalchemy/event.py +++ b/lib/sqlalchemy/event.py @@ -186,12 +186,18 @@ class _ListenerCollection(object): def __call__(self, *args, **kw): """Execute this event.""" - for fn in self: + for fn in self.parent_listeners + self.listeners: fn(*args, **kw) # I'm not entirely thrilled about the overhead here, # but this allows class-level listeners to be added # at any point. + # + # alternatively, _DispatchDescriptor could notify + # all _ListenerCollection objects, but then we move + # to a higher memory model, i.e.weakrefs to all _ListenerCollection + # objects, the _DispatchDescriptor collection repeated + # for all instances. def __len__(self): return len(self.parent_listeners + self.listeners) diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 6df13e94f0..0f8a7d95ca 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -281,12 +281,10 @@ def relationship(argument, secondary=None, **kwargs): docstring which will be applied to the resulting descriptor. :param extension: - an :class:`AttributeExtension` instance, or list of extensions, + an :class:`.AttributeExtension` instance, or list of extensions, which will be prepended to the list of attribute listeners for - the resulting descriptor placed on the class. These listeners - will receive append and set events before the operation - proceeds, and may be used to halt (via exception throw) or - change the value used in the operation. + the resulting descriptor placed on the class. + **Deprecated.** Please see :class:`.AttributeEvents`. :param foreign_keys: a list of columns which are to be used as "foreign key" columns. @@ -603,12 +601,13 @@ def column_property(*args, **kwargs): class-bound descriptor. :param extension: - an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, - or list of extensions, which will be prepended to the list of - attribute listeners for the resulting descriptor placed on the class. - These listeners will receive append and set events before the - operation proceeds, and may be used to halt (via exception throw) - or change the value used in the operation. + an + :class:`.AttributeExtension` + instance, or list of extensions, which will be prepended + to the list of attribute listeners for the resulting + descriptor placed on the class. + **Deprecated.** Please see :class:`.AttributeEvents`. + """ @@ -643,12 +642,9 @@ def composite(class_, *cols, **kwargs): class-bound descriptor. :param extension: - an :class:`~sqlalchemy.orm.interfaces.AttributeExtension` instance, + an :class:`.AttributeExtension` instance, or list of extensions, which will be prepended to the list of attribute listeners for the resulting descriptor placed on the class. - These listeners will receive append and set events before the - operation proceeds, and may be used to halt (via exception throw) - or change the value used in the operation. """ return CompositeProperty(class_, *cols, **kwargs) @@ -729,8 +725,7 @@ def mapper(class_, local_table=None, *args, **params): :param extension: A :class:`.MapperExtension` instance or list of :class:`.MapperExtension` instances which will be applied to all operations by this - :class:`.Mapper`. Deprecated. - The event package is now used. + :class:`.Mapper`. **Deprecated.** Please see :class:`.MapperEvents`. :param include_properties: An inclusive list or set of string column names to map. As of SQLAlchemy 0.6.4, this collection may also diff --git a/lib/sqlalchemy/orm/deprecated_interfaces.py b/lib/sqlalchemy/orm/deprecated_interfaces.py index 8aa57c6875..3e4f268d08 100644 --- a/lib/sqlalchemy/orm/deprecated_interfaces.py +++ b/lib/sqlalchemy/orm/deprecated_interfaces.py @@ -393,6 +393,19 @@ class SessionExtension(object): """ + @classmethod + def _adapt_listener(cls, self, listener): + event.listen(listener.before_commit, 'on_before_commit', self) + event.listen(listener.after_commit, 'on_after_commit', self) + event.listen(listener.after_rollback, 'on_after_rollback', self) + event.listen(listener.before_flush, 'on_before_flush', self) + event.listen(listener.after_flush, 'on_after_flush', self) + event.listen(listener.after_flush_postexec, 'on_after_flush_postexec', self) + event.listen(listener.after_begin, 'on_after_begin', self) + event.listen(listener.after_attach, 'on_after_attach', self) + event.listen(listener.after_bulk_update, 'on_after_bulk_update', self) + event.listen(listener.after_bulk_delete, 'on_after_bulk_delete', self) + def before_commit(self, session): """Execute right before commit is called. diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 66d3834521..fc8cab2ed7 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -627,22 +627,160 @@ class MapperEvents(event.Events): raise NotImplementedError("Removal of mapper events not yet implemented") class SessionEvents(event.Events): - """""" + """Define events specific to :class:`.Session` lifecycle. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.orm import sessionmaker + + class my_before_commit(session): + print "before commit!" + + Session = sessionmaker() + + event.listen(my_before_commit, "on_before_commit", Session) + + The :func:`~.event.listen` function will accept + :class:`.Session` objects as well as the return result + of :func:`.sessionmaker` and :func:`.scoped_session`. + + Additionally, it accepts the :class:`.Session` class which + will apply listeners to all :class:`.Session` instances + globally. + + """ + + @classmethod + def accept_with(cls, target): + from sqlalchemy.orm import ScopedSession, Session + if isinstance(target, ScopedSession): + if not isinstance(target.session_factory, type) or \ + not issubclass(target.session_factory, Session): + raise exc.ArgumentError( + "Session event listen on a ScopedSession " + "requries that its creation callable " + "is a Session subclass.") + return target.session_factory + elif isinstance(target, type): + if issubclass(target, ScopedSession): + return Session + elif issubclass(target, Session): + return target + elif isinstance(target, Session): + return target + else: + return None + @classmethod def remove(cls, fn, identifier, target): raise NotImplementedError("Removal of session events not yet implemented") + def on_before_commit(self, session): + """Execute before commit is called. + + Note that this may not be per-flush if a longer running + transaction is ongoing.""" + + def on_after_commit(self, session): + """Execute after a commit has occured. + + Note that this may not be per-flush if a longer running + transaction is ongoing.""" + + def on_after_rollback(self, session): + """Execute after a rollback has occured. + + Note that this may not be per-flush if a longer running + transaction is ongoing.""" + + def on_before_flush( self, session, flush_context, instances): + """Execute before flush process has started. + + `instances` is an optional list of objects which were passed to + the ``flush()`` method. """ + + def on_after_flush(self, session, flush_context): + """Execute after flush has completed, but before commit has been + called. + + Note that the session's state is still in pre-flush, i.e. 'new', + 'dirty', and 'deleted' lists still show pre-flush state as well + as the history settings on instance attributes.""" + + def on_after_flush_postexec(self, session, flush_context): + """Execute after flush has completed, and after the post-exec + state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in + their final state. An actual commit() may or may not have + occured, depending on whether or not the flush started its own + transaction or participated in a larger transaction. """ + + def on_after_begin( self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + `transaction` is the SessionTransaction. This method is called + after an engine level transaction is begun on a connection. """ + + def on_after_attach(self, session, instance): + """Execute after an instance is attached to a session. + + This is called after an add, delete or merge. """ + + def on_after_bulk_update( self, session, query, query_context, result): + """Execute after a bulk update operation to the session. + + This is called after a session.query(...).update() + + `query` is the query object that this update operation was + called on. `query_context` was the query context object. + `result` is the result object returned from the bulk operation. + """ + + def on_after_bulk_delete( self, session, query, query_context, result): + """Execute after a bulk delete operation to the session. + + This is called after a session.query(...).delete() + + `query` is the query object that this delete operation was + called on. `query_context` was the query context object. + `result` is the result object returned from the bulk operation. + """ + + class AttributeEvents(event.Events): """Define events for object attributes. + + These are typically defined on the class-bound descriptor for the + target class. e.g.:: from sqlalchemy import event + + def my_append_listener(target, value, initiator): + print "received append event for target: %s" % target + event.listen(my_append_listener, 'on_append', MyClass.collection) - event.listen(my_set_listener, 'on_set', - MyClass.somescalar, retval=True) + + Listeners have the option to return a possibly modified version + of the value, when the ``retval=True`` flag is passed + to :func:`~.event.listen`:: + + def validate_phone(target, value, oldvalue, initiator): + "Strip non-numeric characters from a phone number" + + return re.sub(r'(?![0-9])', '', value) + + # setup listener on UserContact.phone attribute, instructing + # it to use the return value + listen(validate_phone, 'on_set', UserContact.phone, retval=True) + + A validation function like the above can also raise an exception + such as :class:`ValueError` to halt the operation. - Several modifiers are available to the :func:`~event.listen` function. + Several modifiers are available to the :func:`~.event.listen` function. :param active_history=False: When True, indicates that the "on_set" event would like to receive the "old" value diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index e324c3f9dd..5c608d1f4b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -2160,8 +2160,7 @@ class Query(object): ) ) - for ext in session.extensions: - ext.after_bulk_delete(session, self, context, result) + session.dispatch.on_after_bulk_delete(session, self, context, result) return result.rowcount @@ -2310,9 +2309,8 @@ class Query(object): session.identity_map[identity_key], [_attr_as_key(k) for k in values] ) - - for ext in session.extensions: - ext.after_bulk_update(session, self, context, result) + + session.dispatch.on_after_bulk_update(session, self, context, result) return result.rowcount diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index c384cfc3d0..a1eb4c46d1 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -22,6 +22,9 @@ from sqlalchemy.orm.util import ( from sqlalchemy.orm.mapper import Mapper, _none_set from sqlalchemy.orm.unitofwork import UOWTransaction from sqlalchemy.orm import identity +from sqlalchemy import event +from sqlalchemy.orm.events import SessionEvents + import sys __all__ = ['Session', 'SessionTransaction', 'SessionExtension'] @@ -133,11 +136,10 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, from the most recent database state. :param extension: An optional - :class:`~sqlalchemy.orm.session.SessionExtension` instance, or a list + :class:`~.SessionExtension` instance, or a list of such instances, which will receive pre- and post- commit and flush - events, as well as a post-rollback event. User- defined code may be - placed within these hooks using a user-defined subclass of - ``SessionExtension``. + events, as well as a post-rollback event. **Deprecated.** + Please see :class:`.SessionEvents`. :param query_cls: Class which should be used to create new Query objects, as returned by the ``query()`` method. Defaults to @@ -177,6 +179,7 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, local_kwargs.setdefault(k, kwargs[k]) super(Sess, self).__init__(**local_kwargs) + @classmethod def configure(self, **new_kwargs): """(Re)configure the arguments for this sessionmaker. @@ -187,9 +190,9 @@ def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, Session.configure(bind=create_engine('sqlite://')) """ kwargs.update(new_kwargs) - configure = classmethod(configure) - s = type.__new__(type, "Session", (Sess, class_), {}) - return s + + + return type("Session", (Sess, class_), {}) class SessionTransaction(object): @@ -344,8 +347,7 @@ class SessionTransaction(object): self._connections[conn] = self._connections[conn.engine] = \ (conn, transaction, conn is not bind) - for ext in self.session.extensions: - ext.after_begin(self.session, self, conn) + self.session.dispatch.on_after_begin(self.session, self, conn) return conn def prepare(self): @@ -357,8 +359,7 @@ class SessionTransaction(object): def _prepare_impl(self): self._assert_is_active() if self._parent is None or self.nested: - for ext in self.session.extensions: - ext.before_commit(self.session) + self.session.dispatch.on_before_commit(self.session) stx = self.session.transaction if stx is not self: @@ -388,8 +389,7 @@ class SessionTransaction(object): for t in set(self._connections.values()): t[1].commit() - for ext in self.session.extensions: - ext.after_commit(self.session) + self.session.dispatch.on_after_commit(self.session) if self.session._enable_transaction_accounting: self._remove_snapshot() @@ -426,8 +426,7 @@ class SessionTransaction(object): if self.session._enable_transaction_accounting: self._restore_snapshot() - for ext in self.session.extensions: - ext.after_rollback(self.session) + self.session.dispatch.on_after_rollback(self.session) def _deactivate(self): self._active = False @@ -511,9 +510,13 @@ class Session(object): self.expire_on_commit = expire_on_commit self._enable_transaction_accounting = _enable_transaction_accounting self.twophase = twophase - self.extensions = util.to_list(extension) or [] self._query_cls = query_cls self._mapper_flush_opts = {} + + if extension: + for ext in util.to_list(extension): + SessionExtension._adapt_listener(self, extension) + if binds is not None: for mapperortable, bind in binds.iteritems(): if isinstance(mapperortable, (type, Mapper)): @@ -525,6 +528,8 @@ class Session(object): self.begin() _sessions[self.hash_key] = self + dispatch = event.dispatcher(SessionEvents) + def begin(self, subtransactions=False, nested=False): """Begin a transaction on this Session. @@ -1325,8 +1330,8 @@ class Session(object): if state.session_id != self.hash_key: state.session_id = self.hash_key - for ext in self.extensions: - ext.after_attach(self, state.obj()) + if self.dispatch.on_after_attach: + self.dispatch.on_after_attach(self, state.obj()) def __contains__(self, instance): """Return True if the instance is associated with this session. @@ -1400,10 +1405,11 @@ class Session(object): return flush_context = UOWTransaction(self) - - if self.extensions: - for ext in self.extensions: - ext.before_flush(self, flush_context, objects) + + if self.dispatch.on_before_flush: + self.dispatch.on_before_flush(self, flush_context, objects) + # re-establish "dirty states" in case the listeners + # added dirty = self._dirty_states deleted = set(self._deleted) @@ -1468,8 +1474,7 @@ class Session(object): try: flush_context.execute() - for ext in self.extensions: - ext.after_flush(self, flush_context) + self.dispatch.on_after_flush(self, flush_context) transaction.commit() except: transaction.rollback(_capture_exception=True) @@ -1484,8 +1489,7 @@ class Session(object): # assert self.identity_map._modified == self.identity_map._modified.difference(objects) #self.identity_map._modified.clear() - for ext in self.extensions: - ext.after_flush_postexec(self, flush_context) + self.dispatch.on_after_flush_postexec(self, flush_context) def is_modified(self, instance, include_collections=True, passive=False): """Return ``True`` if instance has modified attributes. diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 7f451698a8..06c9ac0651 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -349,7 +349,7 @@ class PoolEventsTest(PoolTestBase): c2.close() eq_(canary, ['checkin', 'checkin']) - def test_listen_targets(self): + def test_listen_targets_scope(self): canary = [] def listen_one(*args): canary.append("listen_one") @@ -370,7 +370,37 @@ class PoolEventsTest(PoolTestBase): eq_( canary, ["listen_one","listen_four", "listen_two","listen_three"] ) - + + def test_listen_targets_per_subclass(self): + """test that listen() called on a subclass remains specific to that subclass.""" + + canary = [] + def listen_one(*args): + canary.append("listen_one") + def listen_two(*args): + canary.append("listen_two") + def listen_three(*args): + canary.append("listen_three") + + event.listen(listen_one, 'on_connect', pool.Pool) + event.listen(listen_two, 'on_connect', pool.QueuePool) + event.listen(listen_three, 'on_connect', pool.SingletonThreadPool) + + p1 = pool.QueuePool(creator=MockDBAPI().connect) + p2 = pool.SingletonThreadPool(creator=MockDBAPI().connect) + + assert listen_one in p1.dispatch.on_connect + assert listen_two in p1.dispatch.on_connect + assert listen_three not in p1.dispatch.on_connect + assert listen_one in p2.dispatch.on_connect + assert listen_two not in p2.dispatch.on_connect + assert listen_three in p2.dispatch.on_connect + + p1.connect() + eq_(canary, ["listen_one", "listen_two"]) + p2.connect() + eq_(canary, ["listen_one", "listen_two", "listen_one", "listen_three"]) + def teardown(self): # TODO: need to get remove() functionality # going diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 6ac42a6b38..5994c41dac 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -15,7 +15,7 @@ from sqlalchemy.orm import mapper, relationship, backref, joinedload, \ from sqlalchemy.test.testing import eq_ from test.engine import _base as engine_base from test.orm import _base, _fixtures - +from sqlalchemy import event class SessionTest(_fixtures.FixtureTest): run_inserts = None @@ -1164,103 +1164,296 @@ class SessionTest(_fixtures.FixtureTest): assert s.query(Address).one().id == a.id assert s.query(User).first() is None + @testing.resolve_artifact_names - def test_extension(self): + def test_pickled_update(self): mapper(User, users) - log = [] - class MyExt(sa.orm.session.SessionExtension): - def before_commit(self, session): - log.append('before_commit') - def after_commit(self, session): - log.append('after_commit') - def after_rollback(self, session): - log.append('after_rollback') - def before_flush(self, session, flush_context, objects): - log.append('before_flush') - def after_flush(self, session, flush_context): - log.append('after_flush') - def after_flush_postexec(self, session, flush_context): - log.append('after_flush_postexec') - def after_begin(self, session, transaction, connection): - log.append('after_begin') - def after_attach(self, session, instance): - log.append('after_attach') - def after_bulk_update( - self, - session, - query, - query_context, - result, - ): - log.append('after_bulk_update') + sess1 = create_session() + sess2 = create_session() + u1 = User(name='u1') + sess1.add(u1) + assert_raises_message(sa.exc.InvalidRequestError, + 'already attached to session', sess2.add, + u1) + u2 = pickle.loads(pickle.dumps(u1)) + sess2.add(u2) - def after_bulk_delete( - self, - session, - query, - query_context, - result, - ): - log.append('after_bulk_delete') + @testing.resolve_artifact_names + def test_duplicate_update(self): + mapper(User, users) + Session = sessionmaker() + sess = Session() - sess = create_session(extension = MyExt()) + u1 = User(name='u1') + sess.add(u1) + sess.flush() + assert u1.id is not None + + sess.expunge(u1) + + assert u1 not in sess + assert Session.object_session(u1) is None + + u2 = sess.query(User).get(u1.id) + assert u2 is not None and u2 is not u1 + assert u2 in sess + + assert_raises(Exception, lambda: sess.add(u1)) + + sess.expunge(u2) + assert u2 not in sess + assert Session.object_session(u2) is None + + u1.name = "John" + u2.name = "Doe" + + sess.add(u1) + assert u1 in sess + assert Session.object_session(u1) is sess + + sess.flush() + + sess.expunge_all() + + u3 = sess.query(User).get(u1.id) + assert u3 is not u1 and u3 is not u2 and u3.name == u1.name + + @testing.resolve_artifact_names + def test_no_double_save(self): + sess = create_session() + class Foo(object): + def __init__(self): + sess.add(self) + class Bar(Foo): + def __init__(self): + sess.add(self) + Foo.__init__(self) + mapper(Foo, users) + mapper(Bar, users) + + b = Bar() + assert b in sess + assert len(list(sess)) == 1 + + @testing.resolve_artifact_names + def test_identity_map_mutate(self): + mapper(User, users) + + sess = Session() + + sess.add_all([User(name='u1'), User(name='u2'), User(name='u3')]) + sess.commit() + + u1, u2, u3 = sess.query(User).all() + for i, (key, value) in enumerate(sess.identity_map.iteritems()): + if i == 2: + del u3 + gc_collect() + +class SessionEventsTest(_fixtures.FixtureTest): + run_inserts = None + + def test_class_listen(self): + def my_listener(*arg, **kw): + pass + + event.listen(my_listener, 'on_before_flush', Session) + + s = Session() + assert my_listener in s.dispatch.on_before_flush + + def test_sessionmaker_listen(self): + """test that listen can be applied to individual scoped_session() classes.""" + + def my_listener_one(*arg, **kw): + pass + def my_listener_two(*arg, **kw): + pass + + S1 = sessionmaker() + S2 = sessionmaker() + + event.listen(my_listener_one, 'on_before_flush', Session) + event.listen(my_listener_two, 'on_before_flush', S1) + + s1 = S1() + assert my_listener_one in s1.dispatch.on_before_flush + assert my_listener_two in s1.dispatch.on_before_flush + + s2 = S2() + assert my_listener_one in s2.dispatch.on_before_flush + assert my_listener_two not in s2.dispatch.on_before_flush + + def test_scoped_session_invalid_callable(self): + from sqlalchemy.orm import scoped_session + + def my_listener_one(*arg, **kw): + pass + + scope = scoped_session(lambda:Session()) + + assert_raises_message( + sa.exc.ArgumentError, + "Session event listen on a ScopedSession " + "requries that its creation callable is a Session subclass.", + event.listen, my_listener_one, "on_before_flush", scope + ) + + def test_scoped_session_invalid_class(self): + from sqlalchemy.orm import scoped_session + + def my_listener_one(*arg, **kw): + pass + + class NotASession(object): + def __call__(self): + return Session() + + scope = scoped_session(NotASession) + + assert_raises_message( + sa.exc.ArgumentError, + "Session event listen on a ScopedSession " + "requries that its creation callable is a Session subclass.", + event.listen, my_listener_one, "on_before_flush", scope + ) + + def test_scoped_session_listen(self): + from sqlalchemy.orm import scoped_session + + def my_listener_one(*arg, **kw): + pass + + scope = scoped_session(sessionmaker()) + event.listen(my_listener_one, "on_before_flush", scope) + + assert my_listener_one in scope().dispatch.on_before_flush + + def _listener_fixture(self, **kw): + canary = [] + def listener(name): + def go(*arg, **kw): + canary.append(name) + return go + + sess = Session(**kw) + + for evt in [ + 'on_before_commit', + 'on_after_commit', + 'on_after_rollback', + 'on_before_flush', + 'on_after_flush', + 'on_after_flush_postexec', + 'on_after_begin', + 'on_after_attach', + 'on_after_bulk_update', + 'on_after_bulk_delete' + ]: + event.listen(listener(evt), evt, sess) + + return sess, canary + + @testing.resolve_artifact_names + def test_flush_autocommit_hook(self): + + mapper(User, users) + + sess, canary = self._listener_fixture(autoflush=False, autocommit=True) + u = User(name='u1') sess.add(u) sess.flush() - assert log == [ - 'after_attach', - 'before_flush', - 'after_begin', - 'after_flush', - 'before_commit', - 'after_commit', - 'after_flush_postexec', - ] - log = [] - sess = create_session(autocommit=False, extension=MyExt()) + eq_( + canary, + [ 'on_after_attach', 'on_before_flush', 'on_after_begin', + 'on_after_flush', 'on_before_commit', 'on_after_commit', + 'on_after_flush_postexec', ] + ) + + @testing.resolve_artifact_names + def test_flush_noautocommit_hook(self): + sess, canary = self._listener_fixture() + + mapper(User, users) + u = User(name='u1') sess.add(u) sess.flush() - assert log == ['after_attach', 'before_flush', 'after_begin', - 'after_flush', 'after_flush_postexec'] - log = [] + eq_(canary, ['on_after_attach', 'on_before_flush', 'on_after_begin', + 'on_after_flush', 'on_after_flush_postexec']) + + @testing.resolve_artifact_names + def test_flush_in_commit_hook(self): + sess, canary = self._listener_fixture() + + mapper(User, users) + u = User(name='u1') + sess.add(u) + sess.flush() + canary[:] = [] + u.name = 'ed' sess.commit() - assert log == ['before_commit', 'before_flush', 'after_flush', - 'after_flush_postexec', 'after_commit'] - log = [] + eq_(canary, ['on_before_commit', 'on_before_flush', 'on_after_flush', + 'on_after_flush_postexec', 'on_after_commit']) + + def test_standalone_on_commit_hook(self): + sess, canary = self._listener_fixture() sess.commit() - assert log == ['before_commit', 'after_commit'] - log = [] - sess.query(User).delete() - assert log == ['after_begin', 'after_bulk_delete'] - log = [] + eq_(canary, ['on_before_commit', 'on_after_commit']) + + @testing.resolve_artifact_names + def test_on_bulk_update_hook(self): + sess, canary = self._listener_fixture() + mapper(User, users) sess.query(User).update({'name': 'foo'}) - assert log == ['after_bulk_update'] - log = [] - sess = create_session(autocommit=False, extension=MyExt(), - bind=testing.db) + eq_(canary, ['on_after_begin', 'on_after_bulk_update']) + + @testing.resolve_artifact_names + def test_on_bulk_delete_hook(self): + sess, canary = self._listener_fixture() + mapper(User, users) + sess.query(User).delete() + eq_(canary, ['on_after_begin', 'on_after_bulk_delete']) + + def test_connection_emits_after_begin(self): + sess, canary = self._listener_fixture(bind=testing.db) conn = sess.connection() - assert log == ['after_begin'] + eq_(canary, ['on_after_begin']) @testing.resolve_artifact_names - def test_before_flush(self): - """test that the flush plan can be affected during before_flush()""" + def test_reentrant_flush(self): + + mapper(User, users) + + def before_flush(session, flush_context, objects): + session.flush() + + sess = Session() + event.listen(before_flush, 'on_before_flush', sess) + sess.add(User(name='foo')) + assert_raises_message(sa.exc.InvalidRequestError, + 'already flushing', sess.flush) + + @testing.resolve_artifact_names + def test_before_flush_affects_flush_plan(self): mapper(User, users) - class MyExt(sa.orm.session.SessionExtension): - def before_flush(self, session, flush_context, objects): - for obj in list(session.new) + list(session.dirty): - if isinstance(obj, User): - session.add(User(name='another %s' % obj.name)) - for obj in list(session.deleted): - if isinstance(obj, User): - x = session.query(User).filter(User.name - == 'another %s' % obj.name).one() - session.delete(x) + def before_flush(session, flush_context, objects): + for obj in list(session.new) + list(session.dirty): + if isinstance(obj, User): + session.add(User(name='another %s' % obj.name)) + for obj in list(session.deleted): + if isinstance(obj, User): + x = session.query(User).filter(User.name + == 'another %s' % obj.name).one() + session.delete(x) - sess = create_session(extension = MyExt(), autoflush=True) + sess = Session() + event.listen(before_flush, 'on_before_flush', sess) + u = User(name='u1') sess.add(u) sess.flush() @@ -1301,19 +1494,18 @@ class SessionTest(_fixtures.FixtureTest): def test_before_flush_affects_dirty(self): mapper(User, users) - class MyExt(sa.orm.session.SessionExtension): - def before_flush(self, session, flush_context, objects): - for obj in list(session.identity_map.values()): - obj.name += " modified" + def before_flush(session, flush_context, objects): + for obj in list(session.identity_map.values()): + obj.name += " modified" - sess = create_session(extension = MyExt(), autoflush=True) + sess = Session(autoflush=True) + event.listen(before_flush, 'on_before_flush', sess) + u = User(name='u1') sess.add(u) sess.flush() eq_(sess.query(User).order_by(User.name).all(), - [ - User(name='u1') - ] + [User(name='u1')] ) sess.add(User(name='u2')) @@ -1325,106 +1517,97 @@ class SessionTest(_fixtures.FixtureTest): User(name='u2') ] ) - - @testing.resolve_artifact_names - def test_reentrant_flush(self): - - mapper(User, users) - - class MyExt(sa.orm.session.SessionExtension): - def before_flush(s, session, flush_context, objects): - session.flush() - sess = create_session(extension=MyExt()) - sess.add(User(name='foo')) - assert_raises_message(sa.exc.InvalidRequestError, - 'already flushing', sess.flush) + def teardown(self): + # TODO: need to get remove() functionality + # going + Session.dispatch.clear() + super(SessionEventsTest, self).teardown() + - @testing.resolve_artifact_names - def test_pickled_update(self): - mapper(User, users) - sess1 = create_session() - sess2 = create_session() - u1 = User(name='u1') - sess1.add(u1) - assert_raises_message(sa.exc.InvalidRequestError, - 'already attached to session', sess2.add, - u1) - u2 = pickle.loads(pickle.dumps(u1)) - sess2.add(u2) +class SessionExtensionTest(_fixtures.FixtureTest): + run_inserts = None @testing.resolve_artifact_names - def test_duplicate_update(self): + def test_extension(self): mapper(User, users) - Session = sessionmaker() - sess = Session() - - u1 = User(name='u1') - sess.add(u1) - sess.flush() - assert u1.id is not None - - sess.expunge(u1) - - assert u1 not in sess - assert Session.object_session(u1) is None - - u2 = sess.query(User).get(u1.id) - assert u2 is not None and u2 is not u1 - assert u2 in sess - - assert_raises(Exception, lambda: sess.add(u1)) - - sess.expunge(u2) - assert u2 not in sess - assert Session.object_session(u2) is None - - u1.name = "John" - u2.name = "Doe" + log = [] + class MyExt(sa.orm.session.SessionExtension): + def before_commit(self, session): + log.append('before_commit') + def after_commit(self, session): + log.append('after_commit') + def after_rollback(self, session): + log.append('after_rollback') + def before_flush(self, session, flush_context, objects): + log.append('before_flush') + def after_flush(self, session, flush_context): + log.append('after_flush') + def after_flush_postexec(self, session, flush_context): + log.append('after_flush_postexec') + def after_begin(self, session, transaction, connection): + log.append('after_begin') + def after_attach(self, session, instance): + log.append('after_attach') + def after_bulk_update( + self, + session, + query, + query_context, + result, + ): + log.append('after_bulk_update') - sess.add(u1) - assert u1 in sess - assert Session.object_session(u1) is sess + def after_bulk_delete( + self, + session, + query, + query_context, + result, + ): + log.append('after_bulk_delete') + sess = create_session(extension = MyExt()) + u = User(name='u1') + sess.add(u) sess.flush() + assert log == [ + 'after_attach', + 'before_flush', + 'after_begin', + 'after_flush', + 'before_commit', + 'after_commit', + 'after_flush_postexec', + ] + log = [] + sess = create_session(autocommit=False, extension=MyExt()) + u = User(name='u1') + sess.add(u) + sess.flush() + assert log == ['after_attach', 'before_flush', 'after_begin', + 'after_flush', 'after_flush_postexec'] + log = [] + u.name = 'ed' + sess.commit() + assert log == ['before_commit', 'before_flush', 'after_flush', + 'after_flush_postexec', 'after_commit'] + log = [] + sess.commit() + assert log == ['before_commit', 'after_commit'] + log = [] + sess.query(User).delete() + assert log == ['after_begin', 'after_bulk_delete'] + log = [] + sess.query(User).update({'name': 'foo'}) + assert log == ['after_bulk_update'] + log = [] + sess = create_session(autocommit=False, extension=MyExt(), + bind=testing.db) + conn = sess.connection() + assert log == ['after_begin'] - sess.expunge_all() - - u3 = sess.query(User).get(u1.id) - assert u3 is not u1 and u3 is not u2 and u3.name == u1.name - - @testing.resolve_artifact_names - def test_no_double_save(self): - sess = create_session() - class Foo(object): - def __init__(self): - sess.add(self) - class Bar(Foo): - def __init__(self): - sess.add(self) - Foo.__init__(self) - mapper(Foo, users) - mapper(Bar, users) - - b = Bar() - assert b in sess - assert len(list(sess)) == 1 - - @testing.resolve_artifact_names - def test_identity_map_mutate(self): - mapper(User, users) - sess = Session() - - sess.add_all([User(name='u1'), User(name='u2'), User(name='u3')]) - sess.commit() - - u1, u2, u3 = sess.query(User).all() - for i, (key, value) in enumerate(sess.identity_map.iteritems()): - if i == 2: - del u3 - gc_collect() - class DisposedStates(_base.MappedTest): run_setup_mappers = 'once'