from sqlalchemy import *
from sqlalchemy import exceptions
from sqlalchemy.orm import *
+from sqlalchemy.orm.session import SessionExtension
from sqlalchemy.orm.session import Session as SessionCls
from testlib import *
from testlib.tables import *
self._assert_key(key, (User, (1,), None))
key = s.identity_key(User, row=row, entity_name="en")
self._assert_key(key, (User, (1,), "en"))
+
+ def test_extension(self):
+ mapper(User, users)
+ log = []
+ class MyExt(SessionExtension):
+ def before_commit(self, session):
+ log.append('before_commit')
+ def after_commit(self, session):
+ log.append('after_commit')
+ def after_rollback(self, session):
+ log.append('after_rollback')
+ def before_flush(self, session, flush_context, objects):
+ log.append('before_flush')
+ def after_flush(self, session, flush_context):
+ log.append('after_flush')
+ def after_flush_postexec(self, session, flush_context):
+ log.append('after_flush_postexec')
+ sess = create_session(extension = MyExt())
+ u = User()
+ sess.save(u)
+ sess.flush()
+
+ assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
+
+ log = []
+ sess = create_session(transactional=True, extension=MyExt())
+ u = User()
+ sess.save(u)
+ sess.flush()
+ assert log == ['before_flush', 'after_flush', 'after_flush_postexec']
+ log = []
+ sess.commit()
+ assert log == ['before_commit', 'before_flush', 'after_flush', 'after_flush_postexec', 'after_commit']
+
+
class ScopedSessionTest(ORMTest):
def define_tables(self, metadata):