]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix multiple consequent two phase transactions not working with postgres. For some...
authorAnts Aasma <ants.aasma@gmail.com>
Mon, 8 Oct 2007 15:25:51 +0000 (15:25 +0000)
committerAnts Aasma <ants.aasma@gmail.com>
Mon, 8 Oct 2007 15:25:51 +0000 (15:25 +0000)
- add an option to scoped session mapper extension to not automatically save new objects to session.

CHANGES
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/orm/scoping.py
test/engine/transaction.py

diff --git a/CHANGES b/CHANGES
index 9fd4b82ca619401a07d22db60bbbf1795385fa6f..bad34bd22aa7656149af4b75c0fd49cee67090ee 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -48,6 +48,11 @@ CHANGES
 
 - Firebird now uses dialect.preparer to format sequences names
 
+- Fixed breakage with postgres and multiple two phase transactions. For some
+  reason the implicitly started transaction is not enough. [ticket:810]
+
+- Added an option to the _ScopedExt mapper extension to not automatically
+  save new objects to session on object initialization.
 
 0.4.0beta6
 ----------
index 74b9e6f43707da3ed525fec9a8ee540a6ffd23ae..345893524c2ae09f09a5502ff5b8b4d210d36062 100644 (file)
@@ -306,6 +306,9 @@ class PGDialect(default.DefaultDialect):
         return sqltypes.adapt_type(typeobj, colspecs)
 
     def do_begin_twophase(self, connection, xid):
+        # Two phase transactions seem to require that the transaction is explicitly started.
+        # The implicit transactions that usually work aren't enough.
+        connection.execute(sql.text("BEGIN"))
         self.do_begin(connection.connection)
 
     def do_prepare_twophase(self, connection, xid):
index f4cec04333db8559d0f73fff4504e30d2e53ecad..e29f91da7d1c76a76564da1afd1fb8dc09ed9a53 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy.util import ScopedRegistry, to_list
+from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs
 from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session
 from sqlalchemy.orm.session import Session
 from sqlalchemy import exceptions
@@ -52,10 +52,12 @@ class ScopedSession(object):
         """return a mapper() function which associates this ScopedSession with the Mapper."""
         
         from sqlalchemy.orm import mapper
-        validate = kwargs.pop('validate', False)
+        
+        extension_args = dict((arg,kwargs.pop(arg)) for arg in get_cls_kwargs(_ScopedExt) if arg in kwargs)
+        
         kwargs['extension'] = extension = to_list(kwargs.get('extension', []))
-        if validate:
-            extension.append(self.extension.validating())
+        if extension_args:
+            extension.append(self.extension.configure(**extension_args))
         else:
             extension.append(self.extension)
         return mapper(*args, **kwargs)
@@ -89,13 +91,17 @@ for prop in ('close_all','object_session', 'identity_key'):
     setattr(ScopedSession, prop, clslevel(prop))
     
 class _ScopedExt(MapperExtension):
-    def __init__(self, context, validate=False):
+    def __init__(self, context, validate=False, save_on_init=True):
         self.context = context
         self.validate = validate
+        self.save_on_init = save_on_init
     
     def validating(self):
         return _ScopedExt(self.context, validate=True)
-        
+    
+    def configure(self, **kwargs):
+        return _ScopedExt(self.context, **kwargs)
+    
     def get_session(self):
         return self.context.registry()
 
@@ -117,7 +123,8 @@ class _ScopedExt(MapperExtension):
                     if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
                         raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
                 setattr(instance, key, value)
-        session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+        if self.save_on_init:
+            session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
         return EXT_CONTINUE
 
     def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
index 6a26bf597d371582454142ac6b631d3bc1da1126..6a5383b8c48b23c1751f4fd463f8e49957444b26 100644 (file)
@@ -310,7 +310,36 @@ class TransactionTest(PersistTest):
             [(1,)]
         )
         connection2.close()
-
+        
+    @testing.supported('postgres', 'mysql')
+    @testing.exclude('mysql', '<', (5, 0, 3))
+    def testmultipletwophase(self):
+        conn = testbase.db.connect()
+        
+        xa = conn.begin_twophase()
+        conn.execute(users.insert(), user_id=1, user_name='user1')
+        xa.prepare()
+        xa.commit()
+        
+        xa = conn.begin_twophase()
+        conn.execute(users.insert(), user_id=2, user_name='user2')
+        xa.prepare()
+        xa.rollback()
+        
+        xa = conn.begin_twophase()
+        conn.execute(users.insert(), user_id=3, user_name='user3')
+        xa.rollback()
+        
+        xa = conn.begin_twophase()
+        conn.execute(users.insert(), user_id=4, user_name='user4')
+        xa.prepare()
+        xa.commit()
+        
+        result = conn.execute(select([users.c.user_name]).order_by(users.c.user_id))
+        self.assertEqual(result.fetchall(), [('user1',),('user4',)])
+        
+        conn.close()
+        
 class AutoRollbackTest(PersistTest):
     def setUpAll(self):
         global metadata