)
- class MapperExtensionTest(_fixtures.FixtureTest):
-
- """Superseded by MapperEventsTest - test backwards
- compatibility of MapperExtension."""
-
- run_inserts = None
-
- def extension(self):
- methods = []
-
- class Ext(sa.orm.MapperExtension):
- def instrument_class(self, mapper, cls):
- methods.append("instrument_class")
- return sa.orm.EXT_CONTINUE
-
- def init_instance(
- self, mapper, class_, oldinit, instance, args, kwargs
- ):
- methods.append("init_instance")
- return sa.orm.EXT_CONTINUE
-
- def init_failed(
- self, mapper, class_, oldinit, instance, args, kwargs
- ):
- methods.append("init_failed")
- return sa.orm.EXT_CONTINUE
-
- def reconstruct_instance(self, mapper, instance):
- methods.append("reconstruct_instance")
- return sa.orm.EXT_CONTINUE
-
- def before_insert(self, mapper, connection, instance):
- methods.append("before_insert")
- return sa.orm.EXT_CONTINUE
-
- def after_insert(self, mapper, connection, instance):
- methods.append("after_insert")
- return sa.orm.EXT_CONTINUE
-
- def before_update(self, mapper, connection, instance):
- methods.append("before_update")
- return sa.orm.EXT_CONTINUE
-
- def after_update(self, mapper, connection, instance):
- methods.append("after_update")
- return sa.orm.EXT_CONTINUE
-
- def before_delete(self, mapper, connection, instance):
- methods.append("before_delete")
- return sa.orm.EXT_CONTINUE
-
- def after_delete(self, mapper, connection, instance):
- methods.append("after_delete")
- return sa.orm.EXT_CONTINUE
-
- return Ext, methods
-
- def test_basic(self):
- """test that common user-defined methods get called."""
-
- User, users = self.classes.User, self.tables.users
-
- Ext, methods = self.extension()
-
- mapper(User, users, extension=Ext())
- sess = create_session()
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- u = sess.query(User).populate_existing().get(u.id)
- sess.expunge_all()
- u = sess.query(User).get(u.id)
- u.name = "u1 changed"
- sess.flush()
- sess.delete(u)
- sess.flush()
- eq_(
- methods,
- [
- "instrument_class",
- "init_instance",
- "before_insert",
- "after_insert",
- "reconstruct_instance",
- "before_update",
- "after_update",
- "before_delete",
- "after_delete",
- ],
- )
-
- def test_inheritance(self):
- users, addresses, User = (
- self.tables.users,
- self.tables.addresses,
- self.classes.User,
- )
-
- Ext, methods = self.extension()
-
- class AdminUser(User):
- pass
-
- mapper(User, users, extension=Ext())
- mapper(
- AdminUser,
- addresses,
- inherits=User,
- properties={"address_id": addresses.c.id},
- )
-
- sess = create_session()
- am = AdminUser(name="au1", email_address="au1@e1")
- sess.add(am)
- sess.flush()
- am = sess.query(AdminUser).populate_existing().get(am.id)
- sess.expunge_all()
- am = sess.query(AdminUser).get(am.id)
- am.name = "au1 changed"
- sess.flush()
- sess.delete(am)
- sess.flush()
- eq_(
- methods,
- [
- "instrument_class",
- "instrument_class",
- "init_instance",
- "before_insert",
- "after_insert",
- "reconstruct_instance",
- "before_update",
- "after_update",
- "before_delete",
- "after_delete",
- ],
- )
-
- def test_before_after_only_collection(self):
- """before_update is called on parent for collection modifications,
- after_update is called even if no columns were updated.
-
- """
-
- keywords, items, item_keywords, Keyword, Item = (
- self.tables.keywords,
- self.tables.items,
- self.tables.item_keywords,
- self.classes.Keyword,
- self.classes.Item,
- )
-
- Ext1, methods1 = self.extension()
- Ext2, methods2 = self.extension()
-
- mapper(
- Item,
- items,
- extension=Ext1(),
- properties={
- "keywords": relationship(Keyword, secondary=item_keywords)
- },
- )
- mapper(Keyword, keywords, extension=Ext2())
-
- sess = create_session()
- i1 = Item(description="i1")
- k1 = Keyword(name="k1")
- sess.add(i1)
- sess.add(k1)
- sess.flush()
- eq_(
- methods1,
- [
- "instrument_class",
- "init_instance",
- "before_insert",
- "after_insert",
- ],
- )
- eq_(
- methods2,
- [
- "instrument_class",
- "init_instance",
- "before_insert",
- "after_insert",
- ],
- )
-
- del methods1[:]
- del methods2[:]
- i1.keywords.append(k1)
- sess.flush()
- eq_(methods1, ["before_update", "after_update"])
- eq_(methods2, [])
-
- def test_inheritance_with_dupes(self):
- """Inheritance with the same extension instance on both mappers."""
-
- users, addresses, User = (
- self.tables.users,
- self.tables.addresses,
- self.classes.User,
- )
-
- Ext, methods = self.extension()
-
- class AdminUser(User):
- pass
-
- ext = Ext()
- mapper(User, users, extension=ext)
- mapper(
- AdminUser,
- addresses,
- inherits=User,
- extension=ext,
- properties={"address_id": addresses.c.id},
- )
-
- sess = create_session()
- am = AdminUser(name="au1", email_address="au1@e1")
- sess.add(am)
- sess.flush()
- am = sess.query(AdminUser).populate_existing().get(am.id)
- sess.expunge_all()
- am = sess.query(AdminUser).get(am.id)
- am.name = "au1 changed"
- sess.flush()
- sess.delete(am)
- sess.flush()
- eq_(
- methods,
- [
- "instrument_class",
- "instrument_class",
- "init_instance",
- "before_insert",
- "after_insert",
- "reconstruct_instance",
- "before_update",
- "after_update",
- "before_delete",
- "after_delete",
- ],
- )
-
- def test_unnecessary_methods_not_evented(self):
- users = self.tables.users
-
- class MyExtension(sa.orm.MapperExtension):
- def before_insert(self, mapper, connection, instance):
- pass
-
- class Foo(object):
- pass
-
- m = mapper(Foo, users, extension=MyExtension())
- assert not m.class_manager.dispatch.load
- assert not m.dispatch.before_update
- assert len(m.dispatch.before_insert) == 1
-
-
- class AttributeExtensionTest(fixtures.MappedTest):
- @classmethod
- def define_tables(cls, metadata):
- Table(
- "t1",
- metadata,
- Column("id", Integer, primary_key=True),
- Column("type", String(40)),
- Column("data", String(50)),
- )
-
- def test_cascading_extensions(self):
- t1 = self.tables.t1
-
- ext_msg = []
-
- class Ex1(sa.orm.AttributeExtension):
- def set(self, state, value, oldvalue, initiator):
- ext_msg.append("Ex1 %r" % value)
- return "ex1" + value
-
- class Ex2(sa.orm.AttributeExtension):
- def set(self, state, value, oldvalue, initiator):
- ext_msg.append("Ex2 %r" % value)
- return "ex2" + value
-
- class A(fixtures.BasicEntity):
- pass
-
- class B(A):
- pass
-
- class C(B):
- pass
-
- mapper(
- A,
- t1,
- polymorphic_on=t1.c.type,
- polymorphic_identity="a",
- properties={"data": column_property(t1.c.data, extension=Ex1())},
- )
- mapper(B, polymorphic_identity="b", inherits=A)
- mapper(
- C,
- polymorphic_identity="c",
- inherits=B,
- properties={"data": column_property(t1.c.data, extension=Ex2())},
- )
-
- a1 = A(data="a1")
- b1 = B(data="b1")
- c1 = C(data="c1")
-
- eq_(a1.data, "ex1a1")
- eq_(b1.data, "ex1b1")
- eq_(c1.data, "ex2c1")
-
- a1.data = "a2"
- b1.data = "b2"
- c1.data = "c2"
- eq_(a1.data, "ex1a2")
- eq_(b1.data, "ex1b2")
- eq_(c1.data, "ex2c2")
-
- eq_(
- ext_msg,
- [
- "Ex1 'a1'",
- "Ex1 'b1'",
- "Ex2 'c1'",
- "Ex1 'a2'",
- "Ex1 'b2'",
- "Ex2 'c2'",
- ],
- )
-
-
- class SessionExtensionTest(_fixtures.FixtureTest):
- run_inserts = None
-
- def test_extension(self):
- User, users = self.classes.User, self.tables.users
-
- mapper(User, users)
- log = []
-
- class MyExt(sa.orm.session.SessionExtension):
- def before_commit(self, session):
- log.append("before_commit")
-
- def after_commit(self, session):
- log.append("after_commit")
-
- def after_rollback(self, session):
- log.append("after_rollback")
-
- def before_flush(self, session, flush_context, objects):
- log.append("before_flush")
-
- def after_flush(self, session, flush_context):
- log.append("after_flush")
-
- def after_flush_postexec(self, session, flush_context):
- log.append("after_flush_postexec")
-
- def after_begin(self, session, transaction, connection):
- log.append("after_begin")
-
- def after_attach(self, session, instance):
- log.append("after_attach")
-
- def after_bulk_update(self, session, query, query_context, result):
- log.append("after_bulk_update")
-
- def after_bulk_delete(self, session, query, query_context, result):
- log.append("after_bulk_delete")
-
- sess = create_session(extension=MyExt())
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- assert log == [
- "after_attach",
- "before_flush",
- "after_begin",
- "after_flush",
- "after_flush_postexec",
- "before_commit",
- "after_commit",
- ]
- log = []
- sess = create_session(autocommit=False, extension=MyExt())
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- assert log == [
- "after_attach",
- "before_flush",
- "after_begin",
- "after_flush",
- "after_flush_postexec",
- ]
- log = []
- u.name = "ed"
- sess.commit()
- assert log == [
- "before_commit",
- "before_flush",
- "after_flush",
- "after_flush_postexec",
- "after_commit",
- ]
- log = []
- sess.commit()
- assert log == ["before_commit", "after_commit"]
- log = []
- sess.query(User).delete()
- assert log == ["after_begin", "after_bulk_delete"]
- log = []
- sess.query(User).update({"name": "foo"})
- assert log == ["after_bulk_update"]
- log = []
- sess = create_session(
- autocommit=False, extension=MyExt(), bind=testing.db
- )
- sess.connection()
- assert log == ["after_begin"]
- sess.close()
-
- def test_multiple_extensions(self):
- User, users = self.classes.User, self.tables.users
-
- log = []
-
- class MyExt1(sa.orm.session.SessionExtension):
- def before_commit(self, session):
- log.append("before_commit_one")
-
- class MyExt2(sa.orm.session.SessionExtension):
- def before_commit(self, session):
- log.append("before_commit_two")
-
- mapper(User, users)
- sess = create_session(extension=[MyExt1(), MyExt2()])
- u = User(name="u1")
- sess.add(u)
- sess.flush()
- assert log == ["before_commit_one", "before_commit_two"]
-
- def test_unnecessary_methods_not_evented(self):
- class MyExtension(sa.orm.session.SessionExtension):
- def before_commit(self, session):
- pass
-
- s = Session(extension=MyExtension())
- assert not s.dispatch.after_commit
- assert len(s.dispatch.before_commit) == 1
-
-
class QueryEventsTest(
- _RemoveListeners, _fixtures.FixtureTest, AssertsCompiledSQL
+ _RemoveListeners,
+ _fixtures.FixtureTest,
+ AssertsCompiledSQL,
+ testing.AssertsExecutionResults,
):
__dialect__ = "default"