]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixed session extension bug [ticket:757]
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Sep 2007 18:32:59 +0000 (18:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Sep 2007 18:32:59 +0000 (18:32 +0000)
CHANGES
lib/sqlalchemy/orm/session.py
test/orm/session.py

diff --git a/CHANGES b/CHANGES
index ca87687f4327ad2be196b00032f6399a04e17c4c..55877928dfa32406e6350a4d5820334810536785 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -59,7 +59,7 @@ CHANGES
   as it does in 0.3 since ~(x==y) compiles to "x != y", but still applies
   to operators like BETWEEN.
 
-- Other tickets: [ticket:768], [ticket:728], [ticket:779]
+- Other tickets: [ticket:768], [ticket:728], [ticket:779], [ticket:757]
 
 0.4.0beta5
 ----------
index 25f8bacab4331edc65467e172a02071c1aec51a3..fd75adff4563056062eb11730f39377ed1408172 100644 (file)
@@ -205,7 +205,7 @@ class SessionTransaction(object):
             return self.__parent
 
         if self.session.extension is not None:
-            self.session.before_commit(self.session)
+            self.session.extension.before_commit(self.session)
             
         if self.autoflush:
             self.session.flush()
@@ -218,7 +218,7 @@ class SessionTransaction(object):
             t[1].commit()
 
         if self.session.extension is not None:
-            self.session.after_commit(self.session)
+            self.session.extension.after_commit(self.session)
 
         self.close()
         return self.__parent
index 3d86a8cf2d5ba17fc7e5d92f7c2e74324b43f727..33d4ba9c85aa3490072d4ab6e3c7719a27860785 100644 (file)
@@ -2,6 +2,7 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy import exceptions
 from sqlalchemy.orm import *
+from sqlalchemy.orm.session import SessionExtension
 from sqlalchemy.orm.session import Session as SessionCls
 from testlib import *
 from testlib.tables import *
@@ -592,7 +593,42 @@ class SessionTest(AssertMixin):
         self._assert_key(key, (User, (1,), None))
         key = s.identity_key(User, row=row, entity_name="en")
         self._assert_key(key, (User, (1,), "en"))
+        
+    def test_extension(self):
+        mapper(User, users)
+        log = []
+        class MyExt(SessionExtension):
+            def before_commit(self, session):
+                log.append('before_commit')
+            def after_commit(self, session):
+                log.append('after_commit')
+            def after_rollback(self, session):
+                log.append('after_rollback')
+            def before_flush(self, session, flush_context, objects):
+                log.append('before_flush')
+            def after_flush(self, session, flush_context):
+                log.append('after_flush')
+            def after_flush_postexec(self, session, flush_context):
+                log.append('after_flush_postexec')
+        sess = create_session(extension = MyExt())
+        u = User()
+        sess.save(u)
+        sess.flush()
+        
+        assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
+        
+        log = []
+        sess = create_session(transactional=True, extension=MyExt())
+        u = User()
+        sess.save(u)
+        sess.flush()
+        assert log == ['before_flush', 'after_flush', 'after_flush_postexec']
 
+        log = []
+        sess.commit()
+        assert log == ['before_commit', 'before_flush', 'after_flush', 'after_flush_postexec', 'after_commit']
+        
+        
 class ScopedSessionTest(ORMTest):
 
     def define_tables(self, metadata):