]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- continue with [ticket:2907] and further clean up how we set up
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Jan 2014 08:22:11 +0000 (03:22 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Jan 2014 08:22:11 +0000 (03:22 -0500)
_reset_agent, so that it's local to the various begin_impl(),
rollback_impl(), etc.  this allows setting/resetting of the flag
to be symmetric.
- don't set _reset_agent if it's not None, don't unset it if it isn't
our own transaction.
- make sure we clean it out in close().
- basically, we're dealing here with pools using "threadlocal" that have a
counter, other various mismatches that the tests bring up
- test for recover() now has to invalidate() the previous connection,
because closing it actually rolls it back (e.g. this test was relying
on the broken behavior).

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/testing/engines.py
test/engine/test_transaction.py

index 5c66f4806e9e17b9fb9f52985f5e82f7b56c306b..1f2b7a3e58b1e46b9ff322ebc4eaf62ecfdc28dd 100644 (file)
@@ -404,7 +404,7 @@ class Connection(Connectable):
         """
 
         if self.__transaction is None:
-            self.__transaction = self.connection._reset_agent = RootTransaction(self)
+            self.__transaction = RootTransaction(self)
             return self.__transaction
         else:
             return Transaction(self, self.__transaction)
@@ -423,9 +423,8 @@ class Connection(Connectable):
         See also :meth:`.Connection.begin`,
         :meth:`.Connection.begin_twophase`.
         """
-
         if self.__transaction is None:
-            self.__transaction = self.connection._reset_agent = RootTransaction(self)
+            self.__transaction = RootTransaction(self)
         else:
             self.__transaction = NestedTransaction(self, self.__transaction)
         return self.__transaction
@@ -453,7 +452,7 @@ class Connection(Connectable):
                 "is already in progress.")
         if xid is None:
             xid = self.engine.dialect.create_xid()
-        self.__transaction = self.connection._reset_agent = TwoPhaseTransaction(self, xid)
+        self.__transaction = TwoPhaseTransaction(self, xid)
         return self.__transaction
 
     def recover_twophase(self):
@@ -470,7 +469,7 @@ class Connection(Connectable):
 
         return self.__transaction is not None
 
-    def _begin_impl(self):
+    def _begin_impl(self, transaction):
         if self._echo:
             self.engine.logger.info("BEGIN (implicit)")
 
@@ -479,6 +478,8 @@ class Connection(Connectable):
 
         try:
             self.engine.dialect.do_begin(self.connection)
+            if self.connection._reset_agent is None:
+                self.connection._reset_agent = transaction
         except Exception as e:
             self._handle_dbapi_exception(e, None, None, None, None)
 
@@ -491,11 +492,14 @@ class Connection(Connectable):
                 self.engine.logger.info("ROLLBACK")
             try:
                 self.engine.dialect.do_rollback(self.connection)
-                self.__transaction = self.connection._reset_agent = None
             except Exception as e:
                 self._handle_dbapi_exception(e, None, None, None, None)
+            finally:
+                if self.connection._reset_agent is self.__transaction:
+                    self.connection._reset_agent = None
+                self.__transaction = None
         else:
-            self.__transaction = self.connection._reset_agent = None
+            self.__transaction = None
 
     def _commit_impl(self, autocommit=False):
         if self._has_events:
@@ -505,9 +509,12 @@ class Connection(Connectable):
             self.engine.logger.info("COMMIT")
         try:
             self.engine.dialect.do_commit(self.connection)
-            self.__transaction = self.connection._reset_agent = None
         except Exception as e:
             self._handle_dbapi_exception(e, None, None, None, None)
+        finally:
+            if self.connection._reset_agent is self.__transaction:
+                self.connection._reset_agent = None
+            self.__transaction = None
 
     def _savepoint_impl(self, name=None):
         if self._has_events:
@@ -536,14 +543,17 @@ class Connection(Connectable):
             self.engine.dialect.do_release_savepoint(self, name)
         self.__transaction = context
 
-    def _begin_twophase_impl(self, xid):
+    def _begin_twophase_impl(self, transaction):
         if self._echo:
             self.engine.logger.info("BEGIN TWOPHASE (implicit)")
         if self._has_events:
-            self.dispatch.begin_twophase(self, xid)
+            self.dispatch.begin_twophase(self, transaction.xid)
 
         if self._still_open_and_connection_is_valid:
-            self.engine.dialect.do_begin_twophase(self, xid)
+            self.engine.dialect.do_begin_twophase(self, transaction.xid)
+
+            if self.connection._reset_agent is None:
+                self.connection._reset_agent = transaction
 
     def _prepare_twophase_impl(self, xid):
         if self._has_events:
@@ -559,8 +569,14 @@ class Connection(Connectable):
 
         if self._still_open_and_connection_is_valid:
             assert isinstance(self.__transaction, TwoPhaseTransaction)
-            self.engine.dialect.do_rollback_twophase(self, xid, is_prepared)
-        self.__transaction = self.connection._reset_agent = None
+            try:
+                self.engine.dialect.do_rollback_twophase(self, xid, is_prepared)
+            finally:
+                if self.connection._reset_agent is self.__transaction:
+                    self.connection._reset_agent = None
+                self.__transaction = None
+        else:
+            self.__transaction = None
 
     def _commit_twophase_impl(self, xid, is_prepared):
         if self._has_events:
@@ -568,8 +584,14 @@ class Connection(Connectable):
 
         if self._still_open_and_connection_is_valid:
             assert isinstance(self.__transaction, TwoPhaseTransaction)
-            self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
-        self.__transaction = self.connection._reset_agent = None
+            try:
+                self.engine.dialect.do_commit_twophase(self, xid, is_prepared)
+            finally:
+                if self.connection._reset_agent is self.__transaction:
+                    self.connection._reset_agent = None
+                self.__transaction = None
+        else:
+            self.__transaction = None
 
     def _autorollback(self):
         if not self.in_transaction():
@@ -601,6 +623,8 @@ class Connection(Connectable):
         else:
             if not self.__branch:
                 conn.close()
+            if conn._reset_agent is self.__transaction:
+                conn._reset_agent = None
             del self.__connection
         self.__can_reconnect = False
         self.__transaction = None
@@ -1224,7 +1248,7 @@ class Transaction(object):
 class RootTransaction(Transaction):
     def __init__(self, connection):
         super(RootTransaction, self).__init__(connection, None)
-        self.connection._begin_impl()
+        self.connection._begin_impl(self)
 
     def _do_rollback(self):
         if self.is_active:
@@ -1273,7 +1297,7 @@ class TwoPhaseTransaction(Transaction):
         super(TwoPhaseTransaction, self).__init__(connection, None)
         self._is_prepared = False
         self.xid = xid
-        self.connection._begin_twophase_impl(self.xid)
+        self.connection._begin_twophase_impl(self)
 
     def prepare(self):
         """Prepare this :class:`.TwoPhaseTransaction`.
index d240645a898a86ee3aa3a8068f75d44b85737723..a74bffe26ba8d4a2b0dc1edffd566cd383191235 100644 (file)
@@ -32,6 +32,9 @@ class ConnectionKiller(object):
     def checkout(self, dbapi_con, con_record, con_proxy):
         self.proxy_refs[con_proxy] = True
 
+    def invalidate(self, dbapi_con, con_record, exception):
+        self.conns.discard((dbapi_con, con_record))
+
     def _safe(self, fn):
         try:
             fn()
@@ -49,7 +52,7 @@ class ConnectionKiller(object):
 
     def close_all(self):
         for rec in list(self.proxy_refs):
-            if rec is not None:
+            if rec is not None and rec.is_valid:
                 self._safe(rec._close)
 
     def _after_test_ctx(self):
@@ -226,6 +229,7 @@ def testing_engine(url=None, options=None):
     if use_reaper:
         event.listen(engine.pool, 'connect', testing_reaper.connect)
         event.listen(engine.pool, 'checkout', testing_reaper.checkout)
+        event.listen(engine.pool, 'invalidate', testing_reaper.invalidate)
         testing_reaper.add_engine(engine)
 
     return engine
index e3f5fc25211be5589231a6c4d8d371ac7acb6c5b..c373133d1fec423533e1bfd649df737500393b49 100644 (file)
@@ -342,7 +342,8 @@ class TransactionTest(fixtures.TestBase):
         transaction = connection.begin_twophase()
         connection.execute(users.insert(), user_id=1, user_name='user1')
         transaction.prepare()
-        connection.close()
+        connection.invalidate()
+
         connection2 = testing.db.connect()
         eq_(connection2.execute(select([users.c.user_id]).
             order_by(users.c.user_id)).fetchall(),