]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug where a "branched" connection, that is the kind you get
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Sep 2014 20:25:26 +0000 (16:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 26 Sep 2014 20:25:26 +0000 (16:25 -0400)
when you call :meth:`.Connection.connect`, would not share transaction
status with the parent.  The architecture of branching has been tweaked
a bit so that the branched connection defers to the parent for
all transactional status and operations.
fixes #3190

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/engine/base.py
test/engine/test_execute.py
test/engine/test_reconnect.py
test/engine/test_transaction.py

index a4f3dd6e59d1f4b47b265e52354b45e2e5f34e67..4d5ab1f06d9fd857e93751e00b18ff814844ae7b 100644 (file)
         a bit so that the branched connection defers to the parent for
         all invalidation status and operations.
 
+    .. change::
+        :tags: bug, sql, engine
+        :tickets: 3190
+
+        Fixed bug where a "branched" connection, that is the kind you get
+        when you call :meth:`.Connection.connect`, would not share transaction
+        status with the parent.  The architecture of branching has been tweaked
+        a bit so that the branched connection defers to the parent for
+        all transactional status and operations.
+
     .. change::
         :tags: bug, declarative
         :tickets: 2670
index ec7aed1c32fa8cd15cda1737191ea67247b67cc8..05bb1f4e593eaae965bee71426f75a590271802a 100644 (file)
@@ -57,29 +57,35 @@ class Connection(Connectable):
         """
         self.engine = engine
         self.dialect = engine.dialect
-        self.__connection = connection or engine.raw_connection()
-        self.__transaction = None
-        self.should_close_with_result = close_with_result
-        self.__savepoint_seq = 0
         self.__branch_from = _branch_from
         self.__branch = _branch_from is not None
-        self.__invalid = False
-        self.__can_reconnect = True
-        if _dispatch:
+
+        if _branch_from:
+            self.__connection = connection
+            self._execution_options = _execution_options
+            self._echo = _branch_from._echo
+            self.should_close_with_result = False
             self.dispatch = _dispatch
-        elif _has_events is None:
-            # if _has_events is sent explicitly as False,
-            # then don't join the dispatch of the engine; we don't
-            # want to handle any of the engine's events in that case.
-            self.dispatch = self.dispatch._join(engine.dispatch)
-        self._has_events = _has_events or (
-            _has_events is None and engine._has_events)
-
-        self._echo = self.engine._should_log_info()
-        if _execution_options:
-            self._execution_options =\
-                engine._execution_options.union(_execution_options)
+            self._has_events = _branch_from._has_events
         else:
+            self.__connection = connection \
+                if connection is not None else engine.raw_connection()
+            self.__transaction = None
+            self.__savepoint_seq = 0
+            self.should_close_with_result = close_with_result
+            self.__invalid = False
+            self.__can_reconnect = True
+            self._echo = self.engine._should_log_info()
+
+            if _has_events is None:
+                # if _has_events is sent explicitly as False,
+                # then don't join the dispatch of the engine; we don't
+                # want to handle any of the engine's events in that case.
+                self.dispatch = self.dispatch._join(engine.dispatch)
+            self._has_events = _has_events or (
+                _has_events is None and engine._has_events)
+
+            assert not _execution_options
             self._execution_options = engine._execution_options
 
         if self._has_events or self.engine._has_events:
@@ -90,8 +96,16 @@ class Connection(Connectable):
         engine and connection; but does not have close_with_result enabled,
         and also whose close() method does nothing.
 
-        This is used to execute "sub" statements within a single execution,
-        usually an INSERT statement.
+        The Core uses this very sparingly, only in the case of
+        custom SQL default functions that are to be INSERTed as the
+        primary key of a row where we need to get the value back, so we have
+        to invoke it distinctly - this is a very uncommon case.
+
+        Userland code accesses _branch() when the connect() or
+        contextual_connect() methods are called.  The branched connection
+        acts as much as possible like the parent, except that it stays
+        connected when a close() event occurs.
+
         """
         if self.__branch_from:
             return self.__branch_from._branch()
@@ -100,6 +114,7 @@ class Connection(Connectable):
                 self.engine,
                 self.__connection,
                 _branch_from=self,
+                _execution_options=self._execution_options,
                 _has_events=self._has_events,
                 _dispatch=self.dispatch)
 
@@ -108,7 +123,10 @@ class Connection(Connectable):
         """return the 'root' connection.
 
         Returns 'self' if this connection is not a branch, else
-        returns the root connection from which we ultimately branched."""
+        returns the root connection from which we ultimately branched.
+
+        """
+
         if self.__branch_from:
             return self.__branch_from
         else:
@@ -232,7 +250,7 @@ class Connection(Connectable):
         """Return True if this connection is closed."""
 
         return '_Connection__connection' not in self.__dict__ \
-            and not self._root.__can_reconnect
+            and not self.__can_reconnect
 
     @property
     def invalidated(self):
@@ -251,7 +269,7 @@ class Connection(Connectable):
 
     def _revalidate_connection(self):
         if self.__branch_from:
-            return self._root._revalidate_connection()
+            return self.__branch_from._revalidate_connection()
 
         if self.__can_reconnect and self.__invalid:
             if self.__transaction is not None:
@@ -360,9 +378,6 @@ class Connection(Connectable):
             :ref:`pool_connection_invalidation`
 
         """
-        if self.__branch_from:
-            self._root.invalidate()
-            return
 
         if self.invalidated:
             return
@@ -370,10 +385,10 @@ class Connection(Connectable):
         if self.closed:
             raise exc.ResourceClosedError("This Connection is closed")
 
-        if self._connection_is_valid:
-            self.__connection.invalidate(exception)
-        del self.__connection
-        self.__invalid = True
+        if self._root._connection_is_valid:
+            self._root.__connection.invalidate(exception)
+        del self._root.__connection
+        self._root.__invalid = True
 
     def detach(self):
         """Detach the underlying DB-API connection from its connection pool.
@@ -436,6 +451,8 @@ class Connection(Connectable):
         :class:`.Engine`.
 
         """
+        if self.__branch_from:
+            return self.__branch_from.begin()
 
         if self.__transaction is None:
             self.__transaction = RootTransaction(self)
@@ -457,6 +474,9 @@ class Connection(Connectable):
         See also :meth:`.Connection.begin`,
         :meth:`.Connection.begin_twophase`.
         """
+        if self.__branch_from:
+            return self.__branch_from.begin_nested()
+
         if self.__transaction is None:
             self.__transaction = RootTransaction(self)
         else:
@@ -480,6 +500,9 @@ class Connection(Connectable):
 
         """
 
+        if self.__branch_from:
+            return self.__branch_from.begin_twophase(xid=xid)
+
         if self.__transaction is not None:
             raise exc.InvalidRequestError(
                 "Cannot start a two phase transaction when a transaction "
@@ -500,10 +523,11 @@ class Connection(Connectable):
 
     def in_transaction(self):
         """Return True if a transaction is in progress."""
-
-        return self.__transaction is not None
+        return self._root.__transaction is not None
 
     def _begin_impl(self, transaction):
+        assert not self.__branch_from
+
         if self._echo:
             self.engine.logger.info("BEGIN (implicit)")
 
@@ -518,6 +542,8 @@ class Connection(Connectable):
             self._handle_dbapi_exception(e, None, None, None, None)
 
     def _rollback_impl(self):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.rollback(self)
 
@@ -537,6 +563,8 @@ class Connection(Connectable):
             self.__transaction = None
 
     def _commit_impl(self, autocommit=False):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.commit(self)
 
@@ -553,6 +581,8 @@ class Connection(Connectable):
             self.__transaction = None
 
     def _savepoint_impl(self, name=None):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.savepoint(self, name)
 
@@ -564,6 +594,8 @@ class Connection(Connectable):
             return name
 
     def _rollback_to_savepoint_impl(self, name, context):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.rollback_savepoint(self, name, context)
 
@@ -572,6 +604,8 @@ class Connection(Connectable):
         self.__transaction = context
 
     def _release_savepoint_impl(self, name, context):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.release_savepoint(self, name, context)
 
@@ -580,6 +614,8 @@ class Connection(Connectable):
         self.__transaction = context
 
     def _begin_twophase_impl(self, transaction):
+        assert not self.__branch_from
+
         if self._echo:
             self.engine.logger.info("BEGIN TWOPHASE (implicit)")
         if self._has_events or self.engine._has_events:
@@ -592,6 +628,8 @@ class Connection(Connectable):
                 self.connection._reset_agent = transaction
 
     def _prepare_twophase_impl(self, xid):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.prepare_twophase(self, xid)
 
@@ -600,6 +638,8 @@ class Connection(Connectable):
             self.engine.dialect.do_prepare_twophase(self, xid)
 
     def _rollback_twophase_impl(self, xid, is_prepared):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.rollback_twophase(self, xid, is_prepared)
 
@@ -616,6 +656,8 @@ class Connection(Connectable):
             self.__transaction = None
 
     def _commit_twophase_impl(self, xid, is_prepared):
+        assert not self.__branch_from
+
         if self._has_events or self.engine._has_events:
             self.dispatch.commit_twophase(self, xid, is_prepared)
 
@@ -653,13 +695,21 @@ class Connection(Connectable):
         and will allow no further operations.
 
         """
+        if self.__branch_from:
+            try:
+                del self.__connection
+            except AttributeError:
+                pass
+            finally:
+                self.__can_reconnect = False
+                return
         try:
             conn = self.__connection
         except AttributeError:
             pass
         else:
-            if not self.__branch:
-                conn.close()
+
+            conn.close()
             if conn._reset_agent is self.__transaction:
                 conn._reset_agent = None
 
@@ -1014,8 +1064,8 @@ class Connection(Connectable):
             result.rowcount
             result.close(_autoclose_connection=False)
 
-        if self.__transaction is None and context.should_autocommit:
-            self._commit_impl(autocommit=True)
+        if context.should_autocommit and self._root.__transaction is None:
+            self._root._commit_impl(autocommit=True)
 
         if result.closed and self.should_close_with_result:
             self.close()
index e14a4fd2a77aba6d708f37bfb39fddf25cb6e6f0..219a145c62a336e199717ab5ca9fea4603a4f485 100644 (file)
@@ -982,6 +982,17 @@ class ExecutionOptionsTest(fixtures.TestBase):
         eq_(c1._execution_options, {"foo": "bar"})
         eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})
 
+    def test_branched_connection_execution_options(self):
+        engine = testing_engine("sqlite://")
+
+        conn = engine.connect()
+        c2 = conn.execution_options(foo="bar")
+        c2_branch = c2.connect()
+        eq_(
+            c2_branch._execution_options,
+            {"foo": "bar"}
+        )
+
 
 class AlternateResultProxyTest(fixtures.TestBase):
     __requires__ = ('sqlite', )
index 26a60730120bd6bb5737c238513b6b6ecbc85379..4500ada6a02b0d13740fa2781f3e20e9d09404e7 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy import testing
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.engines import testing_engine
-from sqlalchemy.testing.mock import Mock, call
+from sqlalchemy.testing.mock import Mock, call, patch
 
 
 class MockError(Exception):
@@ -507,18 +507,21 @@ class RealReconnectTest(fixtures.TestBase):
     def test_branched_invalidate_branch_to_parent(self):
         c1 = self.engine.connect()
 
-        c1_branch = c1.connect()
-        eq_(c1_branch.execute(select([1])).scalar(), 1)
+        with patch.object(self.engine.pool, "logger") as logger:
+            c1_branch = c1.connect()
+            eq_(c1_branch.execute(select([1])).scalar(), 1)
 
-        self.engine.test_shutdown()
+            self.engine.test_shutdown()
 
-        _assert_invalidated(c1_branch.execute, select([1]))
-        assert c1.invalidated
-        assert c1_branch.invalidated
+            _assert_invalidated(c1_branch.execute, select([1]))
+            assert c1.invalidated
+            assert c1_branch.invalidated
 
-        c1_branch._revalidate_connection()
-        assert not c1.invalidated
-        assert not c1_branch.invalidated
+            c1_branch._revalidate_connection()
+            assert not c1.invalidated
+            assert not c1_branch.invalidated
+
+        assert "Invalidate connection" in logger.mock_calls[0][1][0]
 
     def test_branched_invalidate_parent_to_branch(self):
         c1 = self.engine.connect()
@@ -536,6 +539,19 @@ class RealReconnectTest(fixtures.TestBase):
         assert not c1.invalidated
         assert not c1_branch.invalidated
 
+    def test_branch_invalidate_state(self):
+        c1 = self.engine.connect()
+
+        c1_branch = c1.connect()
+
+        eq_(c1_branch.execute(select([1])).scalar(), 1)
+
+        self.engine.test_shutdown()
+
+        _assert_invalidated(c1_branch.execute, select([1]))
+        assert not c1_branch.closed
+        assert not c1_branch._connection_is_valid
+
     def test_ensure_is_disconnect_gets_connection(self):
         def is_disconnect(e, conn, cursor):
             # connection is still present
index d921e9ead4056ec016f8dcffb8f4b4536114d82c..fbaf01db7bb4a426f09576e892870fe4d32313e2 100644 (file)
@@ -133,6 +133,79 @@ class TransactionTest(fixtures.TestBase):
             finally:
                 connection.close()
 
+    def test_branch_nested_rollback(self):
+        connection = testing.db.connect()
+        try:
+            connection.begin()
+            branched = connection.connect()
+            assert branched.in_transaction()
+            branched.execute(users.insert(), user_id=1, user_name='user1')
+            nested = branched.begin()
+            branched.execute(users.insert(), user_id=2, user_name='user2')
+            nested.rollback()
+            assert not connection.in_transaction()
+            eq_(connection.scalar("select count(*) from query_users"), 0)
+
+        finally:
+            connection.close()
+
+    def test_branch_orig_rollback(self):
+        connection = testing.db.connect()
+        try:
+            branched = connection.connect()
+            branched.execute(users.insert(), user_id=1, user_name='user1')
+            nested = branched.begin()
+            assert branched.in_transaction()
+            branched.execute(users.insert(), user_id=2, user_name='user2')
+            nested.rollback()
+            eq_(connection.scalar("select count(*) from query_users"), 1)
+
+        finally:
+            connection.close()
+
+    def test_branch_autocommit(self):
+        connection = testing.db.connect()
+        try:
+            branched = connection.connect()
+            branched.execute(users.insert(), user_id=1, user_name='user1')
+        finally:
+            connection.close()
+        eq_(testing.db.scalar("select count(*) from query_users"), 1)
+
+    @testing.requires.savepoints
+    def test_branch_savepoint_rollback(self):
+        connection = testing.db.connect()
+        try:
+            trans = connection.begin()
+            branched = connection.connect()
+            assert branched.in_transaction()
+            branched.execute(users.insert(), user_id=1, user_name='user1')
+            nested = branched.begin_nested()
+            branched.execute(users.insert(), user_id=2, user_name='user2')
+            nested.rollback()
+            assert connection.in_transaction()
+            trans.commit()
+            eq_(connection.scalar("select count(*) from query_users"), 1)
+
+        finally:
+            connection.close()
+
+    @testing.requires.two_phase_transactions
+    def test_branch_twophase_rollback(self):
+        connection = testing.db.connect()
+        try:
+            branched = connection.connect()
+            assert not branched.in_transaction()
+            branched.execute(users.insert(), user_id=1, user_name='user1')
+            nested = branched.begin_twophase()
+            branched.execute(users.insert(), user_id=2, user_name='user2')
+            nested.rollback()
+            assert not connection.in_transaction()
+            eq_(connection.scalar("select count(*) from query_users"), 1)
+
+        finally:
+            connection.close()
+
     def test_retains_through_options(self):
         connection = testing.db.connect()
         try: