]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
sqlalchemy 2.0 test updates
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Nov 2021 15:24:23 +0000 (11:24 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Nov 2021 15:24:23 +0000 (11:24 -0400)
- disable branched connection tests for 2.x
- dont use future flag for 2.x
- adjust batch tests for autobegin, inconsistent
  SQLite transactional DDL behaviors

Change-Id: I70caf6afecc83f880dc92fa6cbc29e2043c43bb9

alembic/testing/fixtures.py
alembic/testing/requirements.py
alembic/testing/util.py
alembic/util/__init__.py
alembic/util/sqla_compat.py
tests/test_batch.py
tests/test_script_consumption.py

index 5937d48541599e895f303eccabba86aeb1a38707..5e6ba89cb45d3b215ad29da2aaa04ebb11687072 100644 (file)
@@ -29,6 +29,7 @@ from ..util.compat import string_types
 from ..util.compat import text_type
 from ..util.sqla_compat import create_mock_engine
 from ..util.sqla_compat import sqla_14
+from ..util.sqla_compat import sqla_1x
 
 
 testing_config = configparser.ConfigParser()
@@ -36,7 +37,10 @@ testing_config.read(["test.cfg"])
 
 
 class TestBase(SQLAlchemyTestBase):
-    is_sqlalchemy_future = False
+    if sqla_1x:
+        is_sqlalchemy_future = False
+    else:
+        is_sqlalchemy_future = True
 
     @testing.fixture()
     def ops_context(self, migration_context):
index 37780ab0213c70cfa7e082686de041141d0e5874..3947272bc145a81f39281a3fd0082bb8c0b1b196 100644 (file)
@@ -83,6 +83,13 @@ class SuiteRequirements(Requirements):
             "SQLAlchemy 1.4 or greater required",
         )
 
+    @property
+    def sqlalchemy_1x(self):
+        return exclusions.skip_if(
+            lambda config: not util.sqla_1x,
+            "SQLAlchemy 1.x test",
+        )
+
     @property
     def comments(self):
         return exclusions.only_if(
index ccabf9cdc705a1bce6db85143ee76e37de4143db..9d24d0fe931f85fc44afd3990881f7d3e263ad0f 100644 (file)
@@ -5,7 +5,9 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+import re
 import types
+from typing import Union
 
 
 def flag_combinations(*combinations):
@@ -97,11 +99,28 @@ def metadata_fixture(ddl="function"):
     return decorate
 
 
+def _safe_int(value: str) -> Union[int, str]:
+    try:
+        return int(value)
+    except:
+        return value
+
+
 def testing_engine(url=None, options=None, future=False):
     from sqlalchemy.testing import config
     from sqlalchemy.testing.engines import testing_engine
+    from sqlalchemy import __version__
+
+    _vers = tuple(
+        [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
+    )
+    sqla_1x = _vers < (2,)
 
     if not future:
         future = getattr(config._current.options, "future_engine", False)
-    kw = {"future": future} if future else {}
+
+    if sqla_1x:
+        kw = {"future": future} if future else {}
+    else:
+        kw = {}
     return testing_engine(url, options, **kw)
index 15c8f4e509b3d5d733c2335d4f89e6ad16569304..49bee432cdb016c50951500c431defc833da06bf 100644 (file)
@@ -25,6 +25,7 @@ from .pyfiles import template_to_file
 from .sqla_compat import has_computed
 from .sqla_compat import sqla_13
 from .sqla_compat import sqla_14
+from .sqla_compat import sqla_1x
 
 
 if not sqla_13:
index 65b11307b1342bbc160d7f488b8206a4e7052a41..a05e27bc6c1b5f2c640fb08c1234772f9517e985 100644 (file)
@@ -58,6 +58,7 @@ _vers = tuple(
 sqla_13 = _vers >= (1, 3)
 sqla_14 = _vers >= (1, 4)
 sqla_14_26 = _vers >= (1, 4, 26)
+sqla_1x = _vers < (2,)
 
 try:
     from sqlalchemy import Computed  # noqa
@@ -122,6 +123,14 @@ def _safe_begin_connection_transaction(
         return connection.begin()
 
 
+def _safe_commit_connection_transaction(
+    connection: "Connection",
+) -> None:
+    transaction = _get_connection_transaction(connection)
+    if transaction:
+        transaction.commit()
+
+
 def _safe_rollback_connection_transaction(
     connection: "Connection",
 ) -> None:
index 2753bdc3f9ff114726337211d4e1eb4533f670df..700056acb462209ff3acbb5277ff580ad5aec98c 100644 (file)
@@ -39,6 +39,7 @@ from alembic.testing import mock
 from alembic.testing import TestBase
 from alembic.testing.fixtures import op_fixture
 from alembic.util import exc as alembic_exc
+from alembic.util.sqla_compat import _safe_commit_connection_transaction
 from alembic.util.sqla_compat import _select
 from alembic.util.sqla_compat import has_computed
 from alembic.util.sqla_compat import has_identity
@@ -1282,6 +1283,20 @@ class BatchRoundTripTest(TestBase):
         context = MigrationContext.configure(self.conn)
         self.op = Operations(context)
 
+    def tearDown(self):
+        # why commit?  because SQLite has inconsistent treatment
+        # of transactional DDL. A test that runs CREATE TABLE and then
+        # ALTER TABLE to change the name of that table, will end up
+        # committing the CREATE TABLE but not the ALTER. As batch mode
+        # does this with a temp table name that's not even in the
+        # metadata collection, we don't have an explicit drop for it
+        # (though we could do that too).  calling commit means the
+        # ALTER will go through and the drop_all() will then catch it.
+        _safe_commit_connection_transaction(self.conn)
+        with self.conn.begin():
+            self.metadata.drop_all(self.conn)
+        self.conn.close()
+
     @contextmanager
     def _sqlite_referential_integrity(self):
         self.conn.exec_driver_sql("PRAGMA foreign_keys=ON")
@@ -1385,7 +1400,7 @@ class BatchRoundTripTest(TestBase):
                 type_=Integer,
                 existing_type=Boolean(create_constraint=True, name="ck1"),
             )
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
 
         eq_(
             [
@@ -1440,7 +1455,7 @@ class BatchRoundTripTest(TestBase):
             batch_op.drop_column(
                 "x", existing_type=Boolean(create_constraint=True, name="ck1")
             )
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
 
         assert "x" not in (c["name"] for c in insp.get_columns("hasbool"))
 
@@ -1450,7 +1465,7 @@ class BatchRoundTripTest(TestBase):
             batch_op.alter_column(
                 "x", type_=Boolean(create_constraint=True, name="ck1")
             )
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
 
         if exclusions.against(config, "sqlite"):
             eq_(
@@ -1471,14 +1486,6 @@ class BatchRoundTripTest(TestBase):
                 [Integer],
             )
 
-    def tearDown(self):
-        in_t = getattr(self.conn, "in_transaction", lambda: False)
-        if in_t():
-            self.conn.rollback()
-        with self.conn.begin():
-            self.metadata.drop_all(self.conn)
-        self.conn.close()
-
     def _assert_data(self, data, tablename="foo"):
         res = self.conn.execute(text("select * from %s" % tablename))
         if sqla_14:
@@ -1492,7 +1499,7 @@ class BatchRoundTripTest(TestBase):
             batch_op.alter_column("data", type_=String(30))
             batch_op.create_index("ix_data", ["data"])
 
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
         eq_(
             set(
                 (ix["name"], tuple(ix["column_names"]))
@@ -1734,7 +1741,7 @@ class BatchRoundTripTest(TestBase):
         )
 
     def _assert_table_comment(self, tname, comment):
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
 
         tcomment = insp.get_table_comment(tname)
         eq_(tcomment, {"text": comment})
@@ -1794,7 +1801,7 @@ class BatchRoundTripTest(TestBase):
         self._assert_table_comment("foo", None)
 
     def _assert_column_comment(self, tname, cname, comment):
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
 
         cols = {col["name"]: col for col in insp.get_columns(tname)}
         eq_(cols[cname]["comment"], comment)
@@ -2037,7 +2044,7 @@ class BatchRoundTripTest(TestBase):
             ]
         )
         eq_(
-            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            [col["name"] for col in inspect(self.conn).get_columns("foo")],
             ["id", "data", "x", "data2"],
         )
 
@@ -2063,7 +2070,7 @@ class BatchRoundTripTest(TestBase):
             ]
         )
         eq_(
-            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            [col["name"] for col in inspect(self.conn).get_columns("foo")],
             ["id", "data", "x", "data2"],
         )
 
@@ -2084,7 +2091,7 @@ class BatchRoundTripTest(TestBase):
             tablename="nopk",
         )
         eq_(
-            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            [col["name"] for col in inspect(self.conn).get_columns("foo")],
             ["id", "data", "x"],
         )
 
@@ -2104,7 +2111,7 @@ class BatchRoundTripTest(TestBase):
             ]
         )
         eq_(
-            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            [col["name"] for col in inspect(self.conn).get_columns("foo")],
             ["id", "data2", "data", "x"],
         )
 
@@ -2124,7 +2131,7 @@ class BatchRoundTripTest(TestBase):
             ]
         )
         eq_(
-            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            [col["name"] for col in inspect(self.conn).get_columns("foo")],
             ["id", "data", "data2", "x"],
         )
 
@@ -2158,12 +2165,12 @@ class BatchRoundTripTest(TestBase):
             ]
         )
         eq_(
-            [col["name"] for col in inspect(config.db).get_columns("foo")],
+            [col["name"] for col in inspect(self.conn).get_columns("foo")],
             ["id", "data", "x", "data2"],
         )
 
     def test_create_drop_index(self):
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
         eq_(insp.get_indexes("foo"), [])
 
         with self.op.batch_alter_table("foo", recreate="always") as batch_op:
@@ -2178,8 +2185,7 @@ class BatchRoundTripTest(TestBase):
                 {"id": 5, "data": "d5", "x": 9},
             ]
         )
-
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
         eq_(
             [
                 dict(
@@ -2195,7 +2201,7 @@ class BatchRoundTripTest(TestBase):
         with self.op.batch_alter_table("foo", recreate="always") as batch_op:
             batch_op.drop_index("ix_data")
 
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
         eq_(insp.get_indexes("foo"), [])
 
 
@@ -2316,7 +2322,7 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
         ) as batch_op:
             batch_op.add_column(Column("data", Integer))
 
-        insp = inspect(config.db)
+        insp = inspect(self.conn)
 
         eq_(
             [
index 33e40dc179fd4138550a903653983dc8b57c42ce..b3146d3cba095a3913e63765e40d2bac01d3546d 100644 (file)
@@ -117,14 +117,20 @@ class PatchEnvironment:
 
 
 @testing.combinations(
-    (False, True, False),
-    (True, False, False),
-    (True, True, False),
-    (False, True, True),
-    (True, False, True),
-    (True, True, True),
-    argnames="transactional_ddl,transaction_per_migration,branched_connection",
-    id_="rrr",
+    (
+        False,
+        True,
+    ),
+    (
+        True,
+        False,
+    ),
+    (
+        True,
+        True,
+    ),
+    argnames="transactional_ddl,transaction_per_migration",
+    id_="rr",
 )
 class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
     __only_on__ = "sqlite"
@@ -277,6 +283,11 @@ class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
         assert not db.dialect.has_table(db.connect(), "bat")
 
 
+class LegacyApplyVersionsFunctionalTest(ApplyVersionsFunctionalTest):
+    __requires__ = ("sqlalchemy_1x",)
+    branched_connection = True
+
+
 # class level combinations can't do the skips for SQLAlchemy 1.3
 # so we have a separate class
 @testing.combinations(
@@ -621,6 +632,7 @@ run_migrations_online()
 
 
 class BranchedOnlineTransactionalDDLTest(OnlineTransactionalDDLTest):
+    __requires__ = ("sqlalchemy_1x",)
     branched_connection = True