]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Support sqlalchemy 1.4 exec_driver_sql, text() for strings
authorCaselIT <cfederico87@gmail.com>
Tue, 17 Mar 2020 22:03:32 +0000 (23:03 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 18 Mar 2020 20:45:29 +0000 (16:45 -0400)
Adjusted tests so that only connection-explicit execution
is used, along with the use of text() for string invocation.
Tests that are testing explicitly for deprecation warnings
will bypass SQLAlchemy warnings.   Added the RemovedIn20 warning
as an error raise for these two specific deprecation cases.

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Change-Id: I4f6b83366329aa95204522c9e99129021d1899fc

12 files changed:
alembic/ddl/postgresql.py
alembic/runtime/migration.py
alembic/testing/__init__.py
alembic/testing/plugin/plugin_base.py
alembic/testing/requirements.py
alembic/util/__init__.py
alembic/util/sqla_compat.py
tests/requirements.py
tests/test_batch.py
tests/test_command.py
tests/test_environment.py
tests/test_postgresql.py

index 4316a96b4ed0024d82a52bb672968c9f7b68e0fb..4ddc0ed9701f3b8dccc45826b0473f443b0ae53e 100644 (file)
@@ -88,7 +88,10 @@ class PostgresqlImpl(DefaultImpl):
             rendered_metadata_default = "'%s'" % rendered_metadata_default
 
         return not self.connection.scalar(
-            "SELECT %s = %s" % (conn_col_default, rendered_metadata_default)
+            text(
+                "SELECT %s = %s"
+                % (conn_col_default, rendered_metadata_default)
+            )
         )
 
     def alter_column(
@@ -152,7 +155,8 @@ class PostgresqlImpl(DefaultImpl):
                 r"nextval\('(.+?)'::regclass\)", column_info["default"]
             )
             if seq_match:
-                info = inspector.bind.execute(
+                info = sqla_compat._exec_on_inspector(
+                    inspector,
                     text(
                         "select c.relname, a.attname "
                         "from pg_class as c join "
index 49eef713548fdc82beac285daf7ac880c7c25d1f..48408a4fb0854cdd8a48c950af971f047b44f137 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.engine.strategies import MockEngineStrategy
 
 from .. import ddl
 from .. import util
+from ..util import sqla_compat
 from ..util.compat import callable
 from ..util.compat import EncodedIO
 
@@ -205,6 +206,7 @@ class MigrationContext(object):
                     "got %r" % connection,
                     stacklevel=3,
                 )
+
             dialect = connection.dialect
         elif url:
             url = sqla_url.make_url(url)
@@ -442,7 +444,7 @@ class MigrationContext(object):
             self.connection.execute(self._version.delete())
 
     def _has_version_table(self):
-        return self.connection.dialect.has_table(
+        return sqla_compat._connectable_has_table(
             self.connection, self.version_table, self.version_table_schema
         )
 
index 238f2bd51e692651518ab43f5a2b8e06fc0cf92c..f009da930b3b0287dc0e71e2fc799cd642d7160e 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.testing import emits_warning  # noqa
 from sqlalchemy.testing import engines  # noqa
 from sqlalchemy.testing import mock  # noqa
 from sqlalchemy.testing import provide_metadata  # noqa
+from sqlalchemy.testing import uses_deprecated  # noqa
 from sqlalchemy.testing.config import requirements as requires  # noqa
 
 from alembic import util  # noqa
index 276bc56c63fe986813b851a967538ea7796c8cad..2d5e95ab541f6ec5b3a82bf8e7008c4331b2fa26 100644 (file)
@@ -40,6 +40,20 @@ def post_begin():
             "once", category=pytest.PytestDeprecationWarning
         )
 
+    from sqlalchemy import exc
+
+    if hasattr(exc, "RemovedIn20Warning"):
+        warnings.filterwarnings(
+            "error",
+            category=exc.RemovedIn20Warning,
+            message=".*Engine.execute",
+        )
+        warnings.filterwarnings(
+            "error",
+            category=exc.RemovedIn20Warning,
+            message=".*Passing a string",
+        )
+
 
 # override selected SQLAlchemy pytest hooks with vendored functionality
 def stop_test_class(cls):
index 1cb146b13a77b7642c4624e11ca843b3aec40504..4804646130a2bfbfc45340e60ca6c2887b6b3d54 100644 (file)
@@ -84,6 +84,13 @@ class SuiteRequirements(Requirements):
             "SQLAlchemy 1.3 or greater required",
         )
 
+    @property
+    def sqlalchemy_14(self):
+        return exclusions.skip_if(
+            lambda config: not util.sqla_14,
+            "SQLAlchemy 1.4 or greater required",
+        )
+
     @property
     def sqlalchemy_1115(self):
         return exclusions.skip_if(
index 961a18bd9e37ce1187d24ba558ab57bb8e28199c..cc86111bf893138cbecd3a0d4154a159d3db861f 100644 (file)
@@ -29,6 +29,7 @@ from .sqla_compat import sqla_1115  # noqa
 from .sqla_compat import sqla_120  # noqa
 from .sqla_compat import sqla_1216  # noqa
 from .sqla_compat import sqla_13  # noqa
+from .sqla_compat import sqla_14  # noqa
 
 
 if not sqla_110:
index e25dbbd5c49a63d114632b65b63bfd50baf8aaba..d3030332be3ae72ef893f59c45c567724ff7eb08 100644 (file)
@@ -1,6 +1,7 @@
 import re
 
 from sqlalchemy import __version__
+from sqlalchemy import inspect
 from sqlalchemy import schema
 from sqlalchemy import sql
 from sqlalchemy import types as sqltypes
@@ -45,6 +46,23 @@ except ImportError:
 AUTOINCREMENT_DEFAULT = "auto"
 
 
+def _connectable_has_table(connectable, tablename, schemaname):
+    if sqla_14:
+        return inspect(connectable).has_table(tablename, schemaname)
+    else:
+        return connectable.dialect.has_table(
+            connectable, tablename, schemaname
+        )
+
+
+def _exec_on_inspector(inspector, statement, **params):
+    if sqla_14:
+        with inspector._operation_context() as conn:
+            return conn.execute(statement, params)
+    else:
+        return inspector.bind.execute(statement, params)
+
+
 def _server_default_is_computed(column):
     if not has_computed:
         return False
index eb424ca8c5f14d7a7037edb11d42083fd60ab85a..bd258e08f222461bac1345b94b3f6c1d0ed3ea13 100644 (file)
@@ -1,3 +1,5 @@
+from sqlalchemy import text
+
 from alembic.testing import exclusions
 from alembic.testing.requirements import SuiteRequirements
 from alembic.util import sqla_compat
@@ -113,10 +115,13 @@ class DefaultRequirements(SuiteRequirements):
         def check(config):
             if not exclusions.against(config, "postgresql"):
                 return False
-            count = config.db.scalar(
-                "SELECT count(*) FROM pg_extension "
-                "WHERE extname='%s'" % name
-            )
+            with config.db.connect() as conn:
+                count = conn.scalar(
+                    text(
+                        "SELECT count(*) FROM pg_extension "
+                        "WHERE extname='%s'" % name
+                    )
+                )
             return bool(count)
 
         return exclusions.only_if(check, "needs %s extension" % name)
index 5b4d3ec063aef5e220ccff6fca129205a3c7f83e..4c32518d3c5dcbfd7671a551783d84c8ebb3f1d3 100644 (file)
@@ -1334,7 +1334,9 @@ class BatchRoundTripTest(TestBase):
         eq_(
             [
                 dict(row)
-                for row in self.conn.execute("select * from %s" % tablename)
+                for row in self.conn.execute(
+                    text("select * from %s" % tablename)
+                )
             ],
             data,
         )
index 6ecfdbef28b87d56e2ef901b55e15a9440417781..da83da7599d2945449e663016f750769e3c8b068 100644 (file)
@@ -6,6 +6,7 @@ import os
 import re
 
 from sqlalchemy import exc as sqla_exc
+from sqlalchemy import text
 
 from alembic import command
 from alembic import config
@@ -371,11 +372,14 @@ finally:
         r2 = command.revision(self.cfg)
         db = _sqlite_file_db()
         command.upgrade(self.cfg, "head")
-        assert_raises(
-            sqla_exc.IntegrityError,
-            db.execute,
-            "insert into alembic_version values ('%s')" % r2.revision,
-        )
+        with db.connect() as conn:
+            assert_raises(
+                sqla_exc.IntegrityError,
+                conn.execute,
+                text(
+                    "insert into alembic_version values ('%s')" % r2.revision
+                ),
+            )
 
     def test_err_correctly_raised_on_dupe_rows_no_pk(self):
         self._env_fixture(version_table_pk=False)
@@ -383,7 +387,10 @@ finally:
         r2 = command.revision(self.cfg)
         db = _sqlite_file_db()
         command.upgrade(self.cfg, "head")
-        db.execute("insert into alembic_version values ('%s')" % r2.revision)
+        with db.connect() as conn:
+            conn.execute(
+                text("insert into alembic_version values ('%s')" % r2.revision)
+            )
         assert_raises_message(
             util.CommandError,
             "Online migration expected to match one row when "
@@ -664,7 +671,7 @@ class StampMultipleHeadsTest(TestBase, _StampTest):
         eng = _sqlite_file_db()
         with eng.connect() as conn:
             result = conn.execute(
-                "update alembic_version set version_num='fake'"
+                text("update alembic_version set version_num='fake'")
             )
             eq_(result.rowcount, 1)
 
@@ -843,31 +850,39 @@ down_revision = '%s'
 
     def test_stamp_creates_table(self):
         command.stamp(self.cfg, "head")
-        eq_(
-            self.bind.scalar("select version_num from alembic_version"), self.b
-        )
+        with self.bind.connect() as conn:
+            eq_(
+                conn.scalar(text("select version_num from alembic_version")),
+                self.b,
+            )
 
     def test_stamp_existing_upgrade(self):
         command.stamp(self.cfg, self.a)
         command.stamp(self.cfg, self.b)
-        eq_(
-            self.bind.scalar("select version_num from alembic_version"), self.b
-        )
+        with self.bind.connect() as conn:
+            eq_(
+                conn.scalar(text("select version_num from alembic_version")),
+                self.b,
+            )
 
     def test_stamp_existing_downgrade(self):
         command.stamp(self.cfg, self.b)
         command.stamp(self.cfg, self.a)
-        eq_(
-            self.bind.scalar("select version_num from alembic_version"), self.a
-        )
+        with self.bind.connect() as conn:
+            eq_(
+                conn.scalar(text("select version_num from alembic_version")),
+                self.a,
+            )
 
     def test_stamp_version_already_there(self):
         command.stamp(self.cfg, self.b)
         command.stamp(self.cfg, self.b)
 
-        eq_(
-            self.bind.scalar("select version_num from alembic_version"), self.b
-        )
+        with self.bind.connect() as conn:
+            eq_(
+                conn.scalar(text("select version_num from alembic_version")),
+                self.b,
+            )
 
 
 class EditTest(TestBase):
index 68eab3d26f8a3b54a04fd3015c31c72e3968c9f4..7e5eb8393594d8a2be96a399b7113255f8f6b3e7 100644 (file)
@@ -1,5 +1,6 @@
 #!coding: utf-8
 from alembic import command
+from alembic import testing
 from alembic.environment import EnvironmentContext
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
@@ -94,6 +95,9 @@ def upgrade():
             command.upgrade(self.cfg, "arev", sql=True)
         assert "do some SQL thing with a % percent sign %" in buf.getvalue()
 
+    @testing.uses_deprecated(
+        r"The Engine.execute\(\) function/method is considered legacy"
+    )
     def test_warning_on_passing_engine(self):
         env = self._fixture()
 
index cfa265e2cb34382c30865615f6bcf9f9d80a2fc6..8c435102b7300512188974d146c6f0e4e824593c 100644 (file)
@@ -307,17 +307,19 @@ class PGAutocommitBlockTest(TestBase):
         self.conn = conn = config.db.connect()
 
         with conn.begin():
-            conn.execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');")
+            conn.execute(
+                text("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')")
+            )
 
     def tearDown(self):
         with self.conn.begin():
-            self.conn.execute("DROP TYPE mood")
+            self.conn.execute(text("DROP TYPE mood"))
 
     def test_alter_enum(self):
         context = MigrationContext.configure(connection=self.conn)
         with context.begin_transaction(_per_migration=True):
             with context.autocommit_block():
-                context.execute("ALTER TYPE mood ADD VALUE 'soso'")
+                context.execute(text("ALTER TYPE mood ADD VALUE 'soso'"))
 
 
 class PGOfflineEnumTest(TestBase):
@@ -430,25 +432,31 @@ class PostgresqlInlineLiteralTest(TestBase):
     @classmethod
     def setup_class(cls):
         cls.bind = config.db
-        cls.bind.execute(
+        with config.db.connect() as conn:
+            conn.execute(
+                text(
+                    """
+                create table tab (
+                    col varchar(50)
+                )
             """
-            create table tab (
-                col varchar(50)
+                )
             )
-        """
-        )
-        cls.bind.execute(
+            conn.execute(
+                text(
+                    """
+                insert into tab (col) values
+                    ('old data 1'),
+                    ('old data 2.1'),
+                    ('old data 3')
             """
-            insert into tab (col) values
-                ('old data 1'),
-                ('old data 2.1'),
-                ('old data 3')
-        """
-        )
+                )
+            )
 
     @classmethod
     def teardown_class(cls):
-        cls.bind.execute("drop table tab")
+        with cls.bind.connect() as conn:
+            conn.execute(text("drop table tab"))
 
     def setUp(self):
         self.conn = self.bind.connect()
@@ -469,7 +477,7 @@ class PostgresqlInlineLiteralTest(TestBase):
         )
         eq_(
             self.conn.execute(
-                "select count(*) from tab where col='new data'"
+                text("select count(*) from tab where col='new data'")
             ).scalar(),
             1,
         )