From: CaselIT Date: Tue, 17 Mar 2020 22:03:32 +0000 (+0100) Subject: Support sqlalchemy 1.4 exec_driver_sql, text() for strings X-Git-Tag: rel_1_4_2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4f351a6ca8a6b5fe6718203226805f4e1a02a2db;p=thirdparty%2Fsqlalchemy%2Falembic.git Support sqlalchemy 1.4 exec_driver_sql, text() for strings Adjusted tests so that only connection-explicit execution is used, along with the use of text() for string invocation. Tests that are testing explicitly for deprecation warnings will bypass SQLAlchemy warnings. Added the RemovedIn20 warning as an error raise for these two specific deprecation cases. Co-authored-by: Mike Bayer Change-Id: I4f6b83366329aa95204522c9e99129021d1899fc --- diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 4316a96b..4ddc0ed9 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -88,7 +88,10 @@ class PostgresqlImpl(DefaultImpl): rendered_metadata_default = "'%s'" % rendered_metadata_default return not self.connection.scalar( - "SELECT %s = %s" % (conn_col_default, rendered_metadata_default) + text( + "SELECT %s = %s" + % (conn_col_default, rendered_metadata_default) + ) ) def alter_column( @@ -152,7 +155,8 @@ class PostgresqlImpl(DefaultImpl): r"nextval\('(.+?)'::regclass\)", column_info["default"] ) if seq_match: - info = inspector.bind.execute( + info = sqla_compat._exec_on_inspector( + inspector, text( "select c.relname, a.attname " "from pg_class as c join " diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 49eef713..48408a4f 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -14,6 +14,7 @@ from sqlalchemy.engine.strategies import MockEngineStrategy from .. import ddl from .. import util +from ..util import sqla_compat from ..util.compat import callable from ..util.compat import EncodedIO @@ -205,6 +206,7 @@ class MigrationContext(object): "got %r" % connection, stacklevel=3, ) + dialect = connection.dialect elif url: url = sqla_url.make_url(url) @@ -442,7 +444,7 @@ class MigrationContext(object): self.connection.execute(self._version.delete()) def _has_version_table(self): - return self.connection.dialect.has_table( + return sqla_compat._connectable_has_table( self.connection, self.version_table, self.version_table_schema ) diff --git a/alembic/testing/__init__.py b/alembic/testing/__init__.py index 238f2bd5..f009da93 100644 --- a/alembic/testing/__init__.py +++ b/alembic/testing/__init__.py @@ -3,6 +3,7 @@ from sqlalchemy.testing import emits_warning # noqa from sqlalchemy.testing import engines # noqa from sqlalchemy.testing import mock # noqa from sqlalchemy.testing import provide_metadata # noqa +from sqlalchemy.testing import uses_deprecated # noqa from sqlalchemy.testing.config import requirements as requires # noqa from alembic import util # noqa diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py index 276bc56c..2d5e95ab 100644 --- a/alembic/testing/plugin/plugin_base.py +++ b/alembic/testing/plugin/plugin_base.py @@ -40,6 +40,20 @@ def post_begin(): "once", category=pytest.PytestDeprecationWarning ) + from sqlalchemy import exc + + if hasattr(exc, "RemovedIn20Warning"): + warnings.filterwarnings( + "error", + category=exc.RemovedIn20Warning, + message=".*Engine.execute", + ) + warnings.filterwarnings( + "error", + category=exc.RemovedIn20Warning, + message=".*Passing a string", + ) + # override selected SQLAlchemy pytest hooks with vendored functionality def stop_test_class(cls): diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 1cb146b1..48046461 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -84,6 +84,13 @@ class SuiteRequirements(Requirements): "SQLAlchemy 1.3 or greater required", ) + @property + def sqlalchemy_14(self): + return exclusions.skip_if( + lambda config: not util.sqla_14, + "SQLAlchemy 1.4 or greater required", + ) + @property def sqlalchemy_1115(self): return exclusions.skip_if( diff --git a/alembic/util/__init__.py b/alembic/util/__init__.py index 961a18bd..cc86111b 100644 --- a/alembic/util/__init__.py +++ b/alembic/util/__init__.py @@ -29,6 +29,7 @@ from .sqla_compat import sqla_1115 # noqa from .sqla_compat import sqla_120 # noqa from .sqla_compat import sqla_1216 # noqa from .sqla_compat import sqla_13 # noqa +from .sqla_compat import sqla_14 # noqa if not sqla_110: diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index e25dbbd5..d3030332 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -1,6 +1,7 @@ import re from sqlalchemy import __version__ +from sqlalchemy import inspect from sqlalchemy import schema from sqlalchemy import sql from sqlalchemy import types as sqltypes @@ -45,6 +46,23 @@ except ImportError: AUTOINCREMENT_DEFAULT = "auto" +def _connectable_has_table(connectable, tablename, schemaname): + if sqla_14: + return inspect(connectable).has_table(tablename, schemaname) + else: + return connectable.dialect.has_table( + connectable, tablename, schemaname + ) + + +def _exec_on_inspector(inspector, statement, **params): + if sqla_14: + with inspector._operation_context() as conn: + return conn.execute(statement, params) + else: + return inspector.bind.execute(statement, params) + + def _server_default_is_computed(column): if not has_computed: return False diff --git a/tests/requirements.py b/tests/requirements.py index eb424ca8..bd258e08 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -1,3 +1,5 @@ +from sqlalchemy import text + from alembic.testing import exclusions from alembic.testing.requirements import SuiteRequirements from alembic.util import sqla_compat @@ -113,10 +115,13 @@ class DefaultRequirements(SuiteRequirements): def check(config): if not exclusions.against(config, "postgresql"): return False - count = config.db.scalar( - "SELECT count(*) FROM pg_extension " - "WHERE extname='%s'" % name - ) + with config.db.connect() as conn: + count = conn.scalar( + text( + "SELECT count(*) FROM pg_extension " + "WHERE extname='%s'" % name + ) + ) return bool(count) return exclusions.only_if(check, "needs %s extension" % name) diff --git a/tests/test_batch.py b/tests/test_batch.py index 5b4d3ec0..4c32518d 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -1334,7 +1334,9 @@ class BatchRoundTripTest(TestBase): eq_( [ dict(row) - for row in self.conn.execute("select * from %s" % tablename) + for row in self.conn.execute( + text("select * from %s" % tablename) + ) ], data, ) diff --git a/tests/test_command.py b/tests/test_command.py index 6ecfdbef..da83da75 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -6,6 +6,7 @@ import os import re from sqlalchemy import exc as sqla_exc +from sqlalchemy import text from alembic import command from alembic import config @@ -371,11 +372,14 @@ finally: r2 = command.revision(self.cfg) db = _sqlite_file_db() command.upgrade(self.cfg, "head") - assert_raises( - sqla_exc.IntegrityError, - db.execute, - "insert into alembic_version values ('%s')" % r2.revision, - ) + with db.connect() as conn: + assert_raises( + sqla_exc.IntegrityError, + conn.execute, + text( + "insert into alembic_version values ('%s')" % r2.revision + ), + ) def test_err_correctly_raised_on_dupe_rows_no_pk(self): self._env_fixture(version_table_pk=False) @@ -383,7 +387,10 @@ finally: r2 = command.revision(self.cfg) db = _sqlite_file_db() command.upgrade(self.cfg, "head") - db.execute("insert into alembic_version values ('%s')" % r2.revision) + with db.connect() as conn: + conn.execute( + text("insert into alembic_version values ('%s')" % r2.revision) + ) assert_raises_message( util.CommandError, "Online migration expected to match one row when " @@ -664,7 +671,7 @@ class StampMultipleHeadsTest(TestBase, _StampTest): eng = _sqlite_file_db() with eng.connect() as conn: result = conn.execute( - "update alembic_version set version_num='fake'" + text("update alembic_version set version_num='fake'") ) eq_(result.rowcount, 1) @@ -843,31 +850,39 @@ down_revision = '%s' def test_stamp_creates_table(self): command.stamp(self.cfg, "head") - eq_( - self.bind.scalar("select version_num from alembic_version"), self.b - ) + with self.bind.connect() as conn: + eq_( + conn.scalar(text("select version_num from alembic_version")), + self.b, + ) def test_stamp_existing_upgrade(self): command.stamp(self.cfg, self.a) command.stamp(self.cfg, self.b) - eq_( - self.bind.scalar("select version_num from alembic_version"), self.b - ) + with self.bind.connect() as conn: + eq_( + conn.scalar(text("select version_num from alembic_version")), + self.b, + ) def test_stamp_existing_downgrade(self): command.stamp(self.cfg, self.b) command.stamp(self.cfg, self.a) - eq_( - self.bind.scalar("select version_num from alembic_version"), self.a - ) + with self.bind.connect() as conn: + eq_( + conn.scalar(text("select version_num from alembic_version")), + self.a, + ) def test_stamp_version_already_there(self): command.stamp(self.cfg, self.b) command.stamp(self.cfg, self.b) - eq_( - self.bind.scalar("select version_num from alembic_version"), self.b - ) + with self.bind.connect() as conn: + eq_( + conn.scalar(text("select version_num from alembic_version")), + self.b, + ) class EditTest(TestBase): diff --git a/tests/test_environment.py b/tests/test_environment.py index 68eab3d2..7e5eb839 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,5 +1,6 @@ #!coding: utf-8 from alembic import command +from alembic import testing from alembic.environment import EnvironmentContext from alembic.migration import MigrationContext from alembic.script import ScriptDirectory @@ -94,6 +95,9 @@ def upgrade(): command.upgrade(self.cfg, "arev", sql=True) assert "do some SQL thing with a % percent sign %" in buf.getvalue() + @testing.uses_deprecated( + r"The Engine.execute\(\) function/method is considered legacy" + ) def test_warning_on_passing_engine(self): env = self._fixture() diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index cfa265e2..8c435102 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -307,17 +307,19 @@ class PGAutocommitBlockTest(TestBase): self.conn = conn = config.db.connect() with conn.begin(): - conn.execute("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');") + conn.execute( + text("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')") + ) def tearDown(self): with self.conn.begin(): - self.conn.execute("DROP TYPE mood") + self.conn.execute(text("DROP TYPE mood")) def test_alter_enum(self): context = MigrationContext.configure(connection=self.conn) with context.begin_transaction(_per_migration=True): with context.autocommit_block(): - context.execute("ALTER TYPE mood ADD VALUE 'soso'") + context.execute(text("ALTER TYPE mood ADD VALUE 'soso'")) class PGOfflineEnumTest(TestBase): @@ -430,25 +432,31 @@ class PostgresqlInlineLiteralTest(TestBase): @classmethod def setup_class(cls): cls.bind = config.db - cls.bind.execute( + with config.db.connect() as conn: + conn.execute( + text( + """ + create table tab ( + col varchar(50) + ) """ - create table tab ( - col varchar(50) + ) ) - """ - ) - cls.bind.execute( + conn.execute( + text( + """ + insert into tab (col) values + ('old data 1'), + ('old data 2.1'), + ('old data 3') """ - insert into tab (col) values - ('old data 1'), - ('old data 2.1'), - ('old data 3') - """ - ) + ) + ) @classmethod def teardown_class(cls): - cls.bind.execute("drop table tab") + with cls.bind.connect() as conn: + conn.execute(text("drop table tab")) def setUp(self): self.conn = self.bind.connect() @@ -469,7 +477,7 @@ class PostgresqlInlineLiteralTest(TestBase): ) eq_( self.conn.execute( - "select count(*) from tab where col='new data'" + text("select count(*) from tab where col='new data'") ).scalar(), 1, )