From: Mike Bayer Date: Fri, 5 Nov 2021 15:24:23 +0000 (-0400) Subject: sqlalchemy 2.0 test updates X-Git-Tag: rel_1_7_5~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d5a368ca7dbfe8501632cbacc69f04ccbfde48ae;p=thirdparty%2Fsqlalchemy%2Falembic.git sqlalchemy 2.0 test updates - disable branched connection tests for 2.x - dont use future flag for 2.x - adjust batch tests for autobegin, inconsistent SQLite transactional DDL behaviors Change-Id: I70caf6afecc83f880dc92fa6cbc29e2043c43bb9 --- diff --git a/alembic/testing/fixtures.py b/alembic/testing/fixtures.py index 5937d485..5e6ba89c 100644 --- a/alembic/testing/fixtures.py +++ b/alembic/testing/fixtures.py @@ -29,6 +29,7 @@ from ..util.compat import string_types from ..util.compat import text_type from ..util.sqla_compat import create_mock_engine from ..util.sqla_compat import sqla_14 +from ..util.sqla_compat import sqla_1x testing_config = configparser.ConfigParser() @@ -36,7 +37,10 @@ testing_config.read(["test.cfg"]) class TestBase(SQLAlchemyTestBase): - is_sqlalchemy_future = False + if sqla_1x: + is_sqlalchemy_future = False + else: + is_sqlalchemy_future = True @testing.fixture() def ops_context(self, migration_context): diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 37780ab0..3947272b 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -83,6 +83,13 @@ class SuiteRequirements(Requirements): "SQLAlchemy 1.4 or greater required", ) + @property + def sqlalchemy_1x(self): + return exclusions.skip_if( + lambda config: not util.sqla_1x, + "SQLAlchemy 1.x test", + ) + @property def comments(self): return exclusions.only_if( diff --git a/alembic/testing/util.py b/alembic/testing/util.py index ccabf9cd..9d24d0fe 100644 --- a/alembic/testing/util.py +++ b/alembic/testing/util.py @@ -5,7 +5,9 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import re import types +from typing import Union def flag_combinations(*combinations): @@ -97,11 +99,28 @@ def metadata_fixture(ddl="function"): return decorate +def _safe_int(value: str) -> Union[int, str]: + try: + return int(value) + except: + return value + + def testing_engine(url=None, options=None, future=False): from sqlalchemy.testing import config from sqlalchemy.testing.engines import testing_engine + from sqlalchemy import __version__ + + _vers = tuple( + [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)] + ) + sqla_1x = _vers < (2,) if not future: future = getattr(config._current.options, "future_engine", False) - kw = {"future": future} if future else {} + + if sqla_1x: + kw = {"future": future} if future else {} + else: + kw = {} return testing_engine(url, options, **kw) diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 15c8f4e5..49bee432 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -25,6 +25,7 @@ from .pyfiles import template_to_file from .sqla_compat import has_computed from .sqla_compat import sqla_13 from .sqla_compat import sqla_14 +from .sqla_compat import sqla_1x if not sqla_13: diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 65b11307..a05e27bc 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -58,6 +58,7 @@ _vers = tuple( sqla_13 = _vers >= (1, 3) sqla_14 = _vers >= (1, 4) sqla_14_26 = _vers >= (1, 4, 26) +sqla_1x = _vers < (2,) try: from sqlalchemy import Computed # noqa @@ -122,6 +123,14 @@ def _safe_begin_connection_transaction( return connection.begin() +def _safe_commit_connection_transaction( + connection: "Connection", +) -> None: + transaction = _get_connection_transaction(connection) + if transaction: + transaction.commit() + + def _safe_rollback_connection_transaction( connection: "Connection", ) -> None: diff --git a/tests/test_batch.py b/tests/test_batch.py index 2753bdc3..700056ac 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -39,6 +39,7 @@ from alembic.testing import mock from alembic.testing import TestBase from alembic.testing.fixtures import op_fixture from alembic.util import exc as alembic_exc +from alembic.util.sqla_compat import _safe_commit_connection_transaction from alembic.util.sqla_compat import _select from alembic.util.sqla_compat import has_computed from alembic.util.sqla_compat import has_identity @@ -1282,6 +1283,20 @@ class BatchRoundTripTest(TestBase): context = MigrationContext.configure(self.conn) self.op = Operations(context) + def tearDown(self): + # why commit? because SQLite has inconsistent treatment + # of transactional DDL. A test that runs CREATE TABLE and then + # ALTER TABLE to change the name of that table, will end up + # committing the CREATE TABLE but not the ALTER. As batch mode + # does this with a temp table name that's not even in the + # metadata collection, we don't have an explicit drop for it + # (though we could do that too). calling commit means the + # ALTER will go through and the drop_all() will then catch it. + _safe_commit_connection_transaction(self.conn) + with self.conn.begin(): + self.metadata.drop_all(self.conn) + self.conn.close() + @contextmanager def _sqlite_referential_integrity(self): self.conn.exec_driver_sql("PRAGMA foreign_keys=ON") @@ -1385,7 +1400,7 @@ class BatchRoundTripTest(TestBase): type_=Integer, existing_type=Boolean(create_constraint=True, name="ck1"), ) - insp = inspect(config.db) + insp = inspect(self.conn) eq_( [ @@ -1440,7 +1455,7 @@ class BatchRoundTripTest(TestBase): batch_op.drop_column( "x", existing_type=Boolean(create_constraint=True, name="ck1") ) - insp = inspect(config.db) + insp = inspect(self.conn) assert "x" not in (c["name"] for c in insp.get_columns("hasbool")) @@ -1450,7 +1465,7 @@ class BatchRoundTripTest(TestBase): batch_op.alter_column( "x", type_=Boolean(create_constraint=True, name="ck1") ) - insp = inspect(config.db) + insp = inspect(self.conn) if exclusions.against(config, "sqlite"): eq_( @@ -1471,14 +1486,6 @@ class BatchRoundTripTest(TestBase): [Integer], ) - def tearDown(self): - in_t = getattr(self.conn, "in_transaction", lambda: False) - if in_t(): - self.conn.rollback() - with self.conn.begin(): - self.metadata.drop_all(self.conn) - self.conn.close() - def _assert_data(self, data, tablename="foo"): res = self.conn.execute(text("select * from %s" % tablename)) if sqla_14: @@ -1492,7 +1499,7 @@ class BatchRoundTripTest(TestBase): batch_op.alter_column("data", type_=String(30)) batch_op.create_index("ix_data", ["data"]) - insp = inspect(config.db) + insp = inspect(self.conn) eq_( set( (ix["name"], tuple(ix["column_names"])) @@ -1734,7 +1741,7 @@ class BatchRoundTripTest(TestBase): ) def _assert_table_comment(self, tname, comment): - insp = inspect(config.db) + insp = inspect(self.conn) tcomment = insp.get_table_comment(tname) eq_(tcomment, {"text": comment}) @@ -1794,7 +1801,7 @@ class BatchRoundTripTest(TestBase): self._assert_table_comment("foo", None) def _assert_column_comment(self, tname, cname, comment): - insp = inspect(config.db) + insp = inspect(self.conn) cols = {col["name"]: col for col in insp.get_columns(tname)} eq_(cols[cname]["comment"], comment) @@ -2037,7 +2044,7 @@ class BatchRoundTripTest(TestBase): ] ) eq_( - [col["name"] for col in inspect(config.db).get_columns("foo")], + [col["name"] for col in inspect(self.conn).get_columns("foo")], ["id", "data", "x", "data2"], ) @@ -2063,7 +2070,7 @@ class BatchRoundTripTest(TestBase): ] ) eq_( - [col["name"] for col in inspect(config.db).get_columns("foo")], + [col["name"] for col in inspect(self.conn).get_columns("foo")], ["id", "data", "x", "data2"], ) @@ -2084,7 +2091,7 @@ class BatchRoundTripTest(TestBase): tablename="nopk", ) eq_( - [col["name"] for col in inspect(config.db).get_columns("foo")], + [col["name"] for col in inspect(self.conn).get_columns("foo")], ["id", "data", "x"], ) @@ -2104,7 +2111,7 @@ class BatchRoundTripTest(TestBase): ] ) eq_( - [col["name"] for col in inspect(config.db).get_columns("foo")], + [col["name"] for col in inspect(self.conn).get_columns("foo")], ["id", "data2", "data", "x"], ) @@ -2124,7 +2131,7 @@ class BatchRoundTripTest(TestBase): ] ) eq_( - [col["name"] for col in inspect(config.db).get_columns("foo")], + [col["name"] for col in inspect(self.conn).get_columns("foo")], ["id", "data", "data2", "x"], ) @@ -2158,12 +2165,12 @@ class BatchRoundTripTest(TestBase): ] ) eq_( - [col["name"] for col in inspect(config.db).get_columns("foo")], + [col["name"] for col in inspect(self.conn).get_columns("foo")], ["id", "data", "x", "data2"], ) def test_create_drop_index(self): - insp = inspect(config.db) + insp = inspect(self.conn) eq_(insp.get_indexes("foo"), []) with self.op.batch_alter_table("foo", recreate="always") as batch_op: @@ -2178,8 +2185,7 @@ class BatchRoundTripTest(TestBase): {"id": 5, "data": "d5", "x": 9}, ] ) - - insp = inspect(config.db) + insp = inspect(self.conn) eq_( [ dict( @@ -2195,7 +2201,7 @@ class BatchRoundTripTest(TestBase): with self.op.batch_alter_table("foo", recreate="always") as batch_op: batch_op.drop_index("ix_data") - insp = inspect(config.db) + insp = inspect(self.conn) eq_(insp.get_indexes("foo"), []) @@ -2316,7 +2322,7 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest): ) as batch_op: batch_op.add_column(Column("data", Integer)) - insp = inspect(config.db) + insp = inspect(self.conn) eq_( [ diff --git a/tests/test_script_consumption.py b/tests/test_script_consumption.py index 33e40dc1..b3146d3c 100644 --- a/tests/test_script_consumption.py +++ b/tests/test_script_consumption.py @@ -117,14 +117,20 @@ class PatchEnvironment: @testing.combinations( - (False, True, False), - (True, False, False), - (True, True, False), - (False, True, True), - (True, False, True), - (True, True, True), - argnames="transactional_ddl,transaction_per_migration,branched_connection", - id_="rrr", + ( + False, + True, + ), + ( + True, + False, + ), + ( + True, + True, + ), + argnames="transactional_ddl,transaction_per_migration", + id_="rr", ) class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase): __only_on__ = "sqlite" @@ -277,6 +283,11 @@ class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase): assert not db.dialect.has_table(db.connect(), "bat") +class LegacyApplyVersionsFunctionalTest(ApplyVersionsFunctionalTest): + __requires__ = ("sqlalchemy_1x",) + branched_connection = True + + # class level combinations can't do the skips for SQLAlchemy 1.3 # so we have a separate class @testing.combinations( @@ -621,6 +632,7 @@ run_migrations_online() class BranchedOnlineTransactionalDDLTest(OnlineTransactionalDDLTest): + __requires__ = ("sqlalchemy_1x",) branched_connection = True