From: Mike Bayer Date: Sat, 22 Sep 2007 18:32:59 +0000 (+0000) Subject: fixed session extension bug [ticket:757] X-Git-Tag: rel_0_4beta6~17 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=94aa03c63f470a868150cf36baccd34ea3d2210e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fixed session extension bug [ticket:757] --- diff --git a/CHANGES b/CHANGES index ca87687f43..55877928df 100644 --- a/CHANGES +++ b/CHANGES @@ -59,7 +59,7 @@ CHANGES as it does in 0.3 since ~(x==y) compiles to "x != y", but still applies to operators like BETWEEN. -- Other tickets: [ticket:768], [ticket:728], [ticket:779] +- Other tickets: [ticket:768], [ticket:728], [ticket:779], [ticket:757] 0.4.0beta5 ---------- diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 25f8bacab4..fd75adff45 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -205,7 +205,7 @@ class SessionTransaction(object): return self.__parent if self.session.extension is not None: - self.session.before_commit(self.session) + self.session.extension.before_commit(self.session) if self.autoflush: self.session.flush() @@ -218,7 +218,7 @@ class SessionTransaction(object): t[1].commit() if self.session.extension is not None: - self.session.after_commit(self.session) + self.session.extension.after_commit(self.session) self.close() return self.__parent diff --git a/test/orm/session.py b/test/orm/session.py index 3d86a8cf2d..33d4ba9c85 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -2,6 +2,7 @@ import testbase from sqlalchemy import * from sqlalchemy import exceptions from sqlalchemy.orm import * +from sqlalchemy.orm.session import SessionExtension from sqlalchemy.orm.session import Session as SessionCls from testlib import * from testlib.tables import * @@ -592,7 +593,42 @@ class SessionTest(AssertMixin): self._assert_key(key, (User, (1,), None)) key = s.identity_key(User, row=row, entity_name="en") self._assert_key(key, (User, (1,), "en")) + + def test_extension(self): + mapper(User, users) + log = [] + class MyExt(SessionExtension): + def before_commit(self, session): + log.append('before_commit') + def after_commit(self, session): + log.append('after_commit') + def after_rollback(self, session): + log.append('after_rollback') + def before_flush(self, session, flush_context, objects): + log.append('before_flush') + def after_flush(self, session, flush_context): + log.append('after_flush') + def after_flush_postexec(self, session, flush_context): + log.append('after_flush_postexec') + sess = create_session(extension = MyExt()) + u = User() + sess.save(u) + sess.flush() + + assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec'] + + log = [] + sess = create_session(transactional=True, extension=MyExt()) + u = User() + sess.save(u) + sess.flush() + assert log == ['before_flush', 'after_flush', 'after_flush_postexec'] + log = [] + sess.commit() + assert log == ['before_commit', 'before_flush', 'after_flush', 'after_flush_postexec', 'after_commit'] + + class ScopedSessionTest(ORMTest): def define_tables(self, metadata):