]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
retrieve 1.3 transaction from branched connection properly
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Jan 2021 19:26:46 +0000 (14:26 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Jan 2021 19:27:26 +0000 (14:27 -0500)
Fixed regression where Alembic would fail to create a transaction properly
if the :class:`sqlalchemy.engine.Connection` were a so-called "branched"
connection, that is, one where the ``.connect()`` method had been called to
create a "sub" connection.

Change-Id: I5319838a08686ede7dc873ce5d39428b1afdf6ff
Fixes: #782
alembic/util/sqla_compat.py
docs/build/unreleased/782.rst [new file with mode: 0644]
tests/test_script_consumption.py

index 29c2519d000e2890bb431a4d7e59eaf422ee14f3..d23f1ebbf11708b39a9c742b93788311f8d0cec6 100644 (file)
@@ -102,7 +102,7 @@ def _get_connection_transaction(connection):
     if sqla_14:
         return connection.get_transaction()
     else:
-        return connection._Connection__transaction
+        return connection._root._Connection__transaction
 
 
 def _create_url(*arg, **kw):
diff --git a/docs/build/unreleased/782.rst b/docs/build/unreleased/782.rst
new file mode 100644 (file)
index 0000000..0a00b46
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, environment
+    :tickets: 782
+
+    Fixed regression where Alembic would fail to create a transaction properly
+    if the :class:`sqlalchemy.engine.Connection` were a so-called "branched"
+    connection, that is, one where the ``.connect()`` method had been called to
+    create a "sub" connection.
index e1b094f6e0017d801ce9270a642d39e9b0ab6de7..e7eda64856a1e3641529724fa977a45eab278634 100644 (file)
@@ -32,6 +32,8 @@ from alembic.util import compat
 
 
 class PatchEnvironment(object):
+    branched_connection = False
+
     @contextmanager
     def _patch_environment(self, transactional_ddl, transaction_per_migration):
         conf = EnvironmentContext.configure
@@ -55,22 +57,74 @@ class PatchEnvironment(object):
             # mode
             assert not conn[0].in_transaction()
 
+    @staticmethod
+    def _branched_connection_env():
+        if config.requirements.sqlalchemy_14.enabled:
+            connect_warning = (
+                'r"The Connection.connect\\(\\) method is considered legacy"'
+            )
+            close_warning = (
+                'r"The .close\\(\\) method on a '
+                "so-called 'branched' connection\""
+            )
+        else:
+            connect_warning = close_warning = ""
+
+        env_file_fixture(
+            textwrap.dedent(
+                """\
+            import alembic
+            from alembic import context
+            from sqlalchemy import engine_from_config, pool
+            from sqlalchemy.testing import expect_warnings
+
+            config = context.config
+
+            target_metadata = None
+
+            def run_migrations_online():
+                connectable = engine_from_config(
+                    config.get_section(config.config_ini_section),
+                    prefix='sqlalchemy.',
+                    poolclass=pool.NullPool)
+
+                with connectable.connect() as conn:
+
+                    with expect_warnings(%(connect_warning)s):
+                        connection = conn.connect()
+                    try:
+                            context.configure(
+                                connection=connection,
+                                target_metadata=target_metadata,
+                            )
+                            with context.begin_transaction():
+                                context.run_migrations()
+                    finally:
+                        with expect_warnings(%(close_warning)s):
+                            connection.close()
+
+            if context.is_offline_mode():
+                assert False
+            else:
+                run_migrations_online()
+            """
+                % {
+                    "connect_warning": connect_warning,
+                    "close_warning": close_warning,
+                }
+            )
+        )
+
 
 @testing.combinations(
-    (
-        False,
-        True,
-    ),
-    (
-        True,
-        False,
-    ),
-    (
-        True,
-        True,
-    ),
-    argnames="transactional_ddl,transaction_per_migration",
-    id_="rr",
+    (False, True, False),
+    (True, False, False),
+    (True, True, False),
+    (False, True, True),
+    (True, False, True),
+    (True, True, True),
+    argnames="transactional_ddl,transaction_per_migration,branched_connection",
+    id_="rrr",
 )
 class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
     __only_on__ = "sqlite"
@@ -79,6 +133,7 @@ class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
     future = False
     transactional_ddl = False
     transaction_per_migration = True
+    branched_connection = False
 
     def setUp(self):
         self.bind = _sqlite_file_db(future=self.future)
@@ -86,6 +141,8 @@ class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
         self.cfg = _sqlite_testing_config(
             sourceless=self.sourceless, future=self.future
         )
+        if self.branched_connection:
+            self._branched_connection_env()
 
     def tearDown(self):
         clear_staging_env()
@@ -223,18 +280,9 @@ class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
 # class level combinations can't do the skips for SQLAlchemy 1.3
 # so we have a separate class
 @testing.combinations(
-    (
-        False,
-        True,
-    ),
-    (
-        True,
-        False,
-    ),
-    (
-        True,
-        True,
-    ),
+    (False, True),
+    (True, False),
+    (True, True),
     argnames="transactional_ddl,transaction_per_migration",
     id_="rr",
 )
@@ -396,10 +444,14 @@ class OnlineTransactionalDDLTest(PatchEnvironment, TestBase):
         else:
             self.cfg = _sqlite_testing_config()
 
+        if self.branched_connection:
+            self._branched_connection_env()
+
         script = ScriptDirectory.from_config(self.cfg)
         a = util.rev_id()
         b = util.rev_id()
         c = util.rev_id()
+
         script.generate_revision(a, "revision a", refresh=True)
         write_script(
             script,
@@ -471,20 +523,13 @@ def downgrade():
     # these tests might not be supported anymore; the connection is always
     # going to be in a transaction now even on 1.3.
 
-    @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
-    def test_raise_when_rev_leaves_open_transaction(self, future):
-        a, b, c = self._opened_transaction_fixture(future)
+    def test_raise_when_rev_leaves_open_transaction(self):
+        a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
             transactional_ddl=False, transaction_per_migration=False
         ):
-            if future:
-                with testing.expect_raises_message(
-                    sa.exc.InvalidRequestError,
-                    "a transaction is already begun",
-                ):
-                    command.upgrade(self.cfg, c)
-            elif config.requirements.sqlalchemy_14.enabled:
+            if config.requirements.sqlalchemy_14.enabled:
                 if self.is_sqlalchemy_future:
                     with testing.expect_raises_message(
                         sa.exc.InvalidRequestError,
@@ -500,20 +545,13 @@ def downgrade():
             else:
                 command.upgrade(self.cfg, c)
 
-    @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
-    def test_raise_when_rev_leaves_open_transaction_tpm(self, future):
-        a, b, c = self._opened_transaction_fixture(future)
+    def test_raise_when_rev_leaves_open_transaction_tpm(self):
+        a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
             transactional_ddl=False, transaction_per_migration=True
         ):
-            if future:
-                with testing.expect_raises_message(
-                    sa.exc.InvalidRequestError,
-                    "a transaction is already begun",
-                ):
-                    command.upgrade(self.cfg, c)
-            elif config.requirements.sqlalchemy_14.enabled:
+            if config.requirements.sqlalchemy_14.enabled:
                 if self.is_sqlalchemy_future:
                     with testing.expect_raises_message(
                         sa.exc.InvalidRequestError,
@@ -529,8 +567,7 @@ def downgrade():
             else:
                 command.upgrade(self.cfg, c)
 
-    @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
-    def test_noerr_rev_leaves_open_transaction_transactional_ddl(self, future):
+    def test_noerr_rev_leaves_open_transaction_transactional_ddl(self):
         a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
@@ -584,6 +621,10 @@ run_migrations_online()
         command.stamp(self.cfg, c)
 
 
+class BranchedOnlineTransactionalDDLTest(OnlineTransactionalDDLTest):
+    branched_connection = True
+
+
 class FutureOnlineTransactionalDDLTest(
     FutureEngineMixin, OnlineTransactionalDDLTest
 ):