]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Accommodate SQLAlchemy 1.4/2.0
authorCaselIT <cfederico87@gmail.com>
Mon, 19 Oct 2020 21:23:08 +0000 (23:23 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Jan 2021 04:41:43 +0000 (23:41 -0500)
To accommodate SQLAlchemy 1.4 and 2.0, the migration model now no longer
assumes that the SQLAlchemy Connection will autocommit an individual
operation.   This essentially means that for databases that use
non-transactional DDL (pysqlite current driver behavior, MySQL), there is
still a BEGIN/COMMIT block that will surround each individual migration.
Databases that support transactional DDL should continue to have the
same flow, either per migration or per-entire run, depending on the
value of the :paramref:`.Environment.configure.transaction_per_migration`
flag.

Compatibility is established such that the entire library should
not generate any SQLAlchemy 2.0 deprecation warnings and
SQLALCHEMY_WARN_20 is part of conftest.py. (one warning remains
for the moment that needs to be resolved on the SQLAlchemy side)

The test suite requires SQLAlchemy 1.4.0b2 for testing 1.4;
1.4.0b1 won't work.

Test suite / setup also being modernized, as we are at
SQLAlchemy 1.3 we can now remove the majority of the testing
suite plugin.

Change-Id: If55b1ea3c12ead66405ab3fadc76d15d89dabb90

42 files changed:
alembic/autogenerate/api.py
alembic/ddl/impl.py
alembic/operations/batch.py
alembic/operations/schemaobj.py
alembic/runtime/migration.py
alembic/script/base.py
alembic/script/revision.py
alembic/script/write_hooks.py
alembic/testing/__init__.py
alembic/testing/assertions.py
alembic/testing/env.py
alembic/testing/fixture_functions.py [deleted file]
alembic/testing/fixtures.py
alembic/testing/plugin/bootstrap.py
alembic/testing/plugin/plugin_base.py [deleted file]
alembic/testing/plugin/pytestplugin.py [deleted file]
alembic/testing/requirements.py
alembic/testing/util.py
alembic/testing/warnings.py [new file with mode: 0644]
alembic/util/__init__.py
alembic/util/compat.py
alembic/util/langhelpers.py
alembic/util/pyfiles.py
alembic/util/sqla_compat.py
docs/build/autogenerate.rst
docs/build/unreleased/autocommit.rst [new file with mode: 0644]
setup.cfg
setup.py
tests/conftest.py
tests/requirements.py
tests/test_autogen_indexes.py
tests/test_batch.py
tests/test_bulk_insert.py
tests/test_command.py
tests/test_environment.py
tests/test_impl.py [new file with mode: 0644]
tests/test_mysql.py
tests/test_postgresql.py
tests/test_script_consumption.py
tests/test_sqlite.py
tests/test_version_table.py
tox.ini

index 5d1e84816f799c7f4e3e9fae4fbceaface6cf279..db5fe1264fcf5750d351ee1b5cf4b64e7d0d3a98 100644 (file)
@@ -31,22 +31,23 @@ def compare_metadata(context, metadata):
         from sqlalchemy.schema import SchemaItem
         from sqlalchemy.types import TypeEngine
         from sqlalchemy import (create_engine, MetaData, Column,
-                Integer, String, Table)
+                Integer, String, Table, text)
         import pprint
 
         engine = create_engine("sqlite://")
 
-        engine.execute('''
-            create table foo (
-                id integer not null primary key,
-                old_data varchar,
-                x integer
-            )''')
-
-        engine.execute('''
-            create table bar (
-                data varchar
-            )''')
+        with engine.begin() as conn:
+            conn.execute(text('''
+                create table foo (
+                    id integer not null primary key,
+                    old_data varchar,
+                    x integer
+                )'''))
+
+            conn.execute(text('''
+                create table bar (
+                    data varchar
+                )'''))
 
         metadata = MetaData()
         Table('foo', metadata,
index 3674c67dea28cb9e360d2d6b0efe04157f5fe99f..923fd8b51917a1b58078071b034a00a68fb3d122 100644 (file)
@@ -140,7 +140,10 @@ class DefaultImpl(with_metaclass(ImplMeta)):
             conn = self.connection
             if execution_options:
                 conn = conn.execution_options(**execution_options)
-            return conn.execute(construct, *multiparams, **params)
+            if params:
+                multiparams += (params,)
+
+            return conn.execute(construct, multiparams)
 
     def execute(self, sql, execution_options=None):
         self._exec(sql, execution_options)
@@ -316,7 +319,7 @@ class DefaultImpl(with_metaclass(ImplMeta)):
         if self.as_sql:
             for row in rows:
                 self._exec(
-                    table.insert(inline=True).values(
+                    sqla_compat._insert_inline(table).values(
                         **dict(
                             (
                                 k,
@@ -338,10 +341,14 @@ class DefaultImpl(with_metaclass(ImplMeta)):
                 table._autoincrement_column = None
             if rows:
                 if multiinsert:
-                    self._exec(table.insert(inline=True), multiparams=rows)
+                    self._exec(
+                        sqla_compat._insert_inline(table), multiparams=rows
+                    )
                 else:
                     for row in rows:
-                        self._exec(table.insert(inline=True).values(**row))
+                        self._exec(
+                            sqla_compat._insert_inline(table).values(**row)
+                        )
 
     def _tokenize_column_type(self, column):
         definition = self.dialect.type_compiler.process(column.type).lower()
index 81e3b99585c18a7e61dadc5606e007495093ca00..f0291936e390fb79138ec79ea7cce6a46bd0cf18 100644 (file)
@@ -5,7 +5,6 @@ from sqlalchemy import Index
 from sqlalchemy import MetaData
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema as sql_schema
-from sqlalchemy import select
 from sqlalchemy import Table
 from sqlalchemy import types as sqltypes
 from sqlalchemy.events import SchemaEventTarget
@@ -14,9 +13,12 @@ from sqlalchemy.util import topological
 
 from ..util import exc
 from ..util.sqla_compat import _columns_for_constraint
+from ..util.sqla_compat import _ensure_scope_for_ddl
 from ..util.sqla_compat import _fk_is_self_referential
+from ..util.sqla_compat import _insert_inline
 from ..util.sqla_compat import _is_type_bound
 from ..util.sqla_compat import _remove_column_from_collection
+from ..util.sqla_compat import _select
 
 
 class BatchOperationsImpl(object):
@@ -76,44 +78,44 @@ class BatchOperationsImpl(object):
     def flush(self):
         should_recreate = self._should_recreate()
 
-        if not should_recreate:
-            for opname, arg, kw in self.batch:
-                fn = getattr(self.operations.impl, opname)
-                fn(*arg, **kw)
-        else:
-            if self.naming_convention:
-                m1 = MetaData(naming_convention=self.naming_convention)
+        with _ensure_scope_for_ddl(self.impl.connection):
+            if not should_recreate:
+                for opname, arg, kw in self.batch:
+                    fn = getattr(self.operations.impl, opname)
+                    fn(*arg, **kw)
             else:
-                m1 = MetaData()
+                if self.naming_convention:
+                    m1 = MetaData(naming_convention=self.naming_convention)
+                else:
+                    m1 = MetaData()
 
-            if self.copy_from is not None:
-                existing_table = self.copy_from
-                reflected = False
-            else:
-                existing_table = Table(
-                    self.table_name,
-                    m1,
-                    schema=self.schema,
-                    autoload=True,
-                    autoload_with=self.operations.get_bind(),
-                    *self.reflect_args,
-                    **self.reflect_kwargs
+                if self.copy_from is not None:
+                    existing_table = self.copy_from
+                    reflected = False
+                else:
+                    existing_table = Table(
+                        self.table_name,
+                        m1,
+                        schema=self.schema,
+                        autoload_with=self.operations.get_bind(),
+                        *self.reflect_args,
+                        **self.reflect_kwargs
+                    )
+                    reflected = True
+
+                batch_impl = ApplyBatchImpl(
+                    self.impl,
+                    existing_table,
+                    self.table_args,
+                    self.table_kwargs,
+                    reflected,
+                    partial_reordering=self.partial_reordering,
                 )
-                reflected = True
-
-            batch_impl = ApplyBatchImpl(
-                self.impl,
-                existing_table,
-                self.table_args,
-                self.table_kwargs,
-                reflected,
-                partial_reordering=self.partial_reordering,
-            )
-            for opname, arg, kw in self.batch:
-                fn = getattr(batch_impl, opname)
-                fn(*arg, **kw)
+                for opname, arg, kw in self.batch:
+                    fn = getattr(batch_impl, opname)
+                    fn(*arg, **kw)
 
-            batch_impl._create(self.impl)
+                batch_impl._create(self.impl)
 
     def alter_column(self, *arg, **kw):
         self.batch.append(("alter_column", arg, kw))
@@ -362,14 +364,14 @@ class ApplyBatchImpl(object):
 
         try:
             op_impl._exec(
-                self.new_table.insert(inline=True).from_select(
+                _insert_inline(self.new_table).from_select(
                     list(
                         k
                         for k, transfer in self.column_transfers.items()
                         if "expr" in transfer
                     ),
-                    select(
-                        [
+                    _select(
+                        *[
                             transfer["expr"]
                             for transfer in self.column_transfers.values()
                             if "expr" in transfer
index d90b5e665a6d99c619a69922f7c79335493021e6..5e8aa4fec0d416f0c1405833fddcbaae62372256 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.types import Integer
 from sqlalchemy.types import NULLTYPE
 
 from .. import util
+from ..util.compat import raise_
 from ..util.compat import string_types
 
 
@@ -113,10 +114,13 @@ class SchemaObjects(object):
         }
         try:
             const = types[type_]
-        except KeyError:
-            raise TypeError(
-                "'type' can be one of %s"
-                % ", ".join(sorted(repr(x) for x in types))
+        except KeyError as ke:
+            raise_(
+                TypeError(
+                    "'type' can be one of %s"
+                    % ", ".join(sorted(repr(x) for x in types))
+                ),
+                from_=ke,
             )
         else:
             const = const(name=name)
index 5c8590d6c4fd895014241295f61288ab6fbe7507..48bb842364a9297edd7c5d03ba3dae20a61da568 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy import MetaData
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import String
 from sqlalchemy import Table
-from sqlalchemy.engine import Connection
+from sqlalchemy.engine import Engine
 from sqlalchemy.engine import url as sqla_url
 from sqlalchemy.engine.strategies import MockEngineStrategy
 
@@ -31,15 +31,18 @@ class _ProxyTransaction(object):
 
     def rollback(self):
         self._proxied_transaction.rollback()
+        self.migration_context._transaction = None
 
     def commit(self):
         self._proxied_transaction.commit()
+        self.migration_context._transaction = None
 
     def __enter__(self):
         return self
 
     def __exit__(self, type_, value, traceback):
         self._proxied_transaction.__exit__(type_, value, traceback)
+        self.migration_context._transaction = None
 
 
 class MigrationContext(object):
@@ -105,8 +108,13 @@ class MigrationContext(object):
         if as_sql:
             self.connection = self._stdout_connection(connection)
             assert self.connection is not None
+            self._in_external_transaction = False
         else:
             self.connection = connection
+            self._in_external_transaction = (
+                sqla_compat._get_connection_in_transaction(connection)
+            )
+
         self._migrations_fn = opts.get("fn")
         self.as_sql = as_sql
 
@@ -199,12 +207,11 @@ class MigrationContext(object):
             dialect_opts = {}
 
         if connection:
-            if not isinstance(connection, Connection):
-                util.warn(
+            if isinstance(connection, Engine):
+                raise util.CommandError(
                     "'connection' argument to configure() is expected "
                     "to be a sqlalchemy.engine.Connection instance, "
                     "got %r" % connection,
-                    stacklevel=3,
                 )
 
             dialect = connection.dialect
@@ -268,19 +275,27 @@ class MigrationContext(object):
         """
         _in_connection_transaction = self._in_connection_transaction()
 
-        if self.impl.transactional_ddl:
-            if self.as_sql:
-                self.impl.emit_commit()
+        if self.impl.transactional_ddl and self.as_sql:
+            self.impl.emit_commit()
 
-            elif _in_connection_transaction:
-                assert self._transaction is not None
+        elif _in_connection_transaction:
+            assert self._transaction is not None
 
-                self._transaction.commit()
-                self._transaction = None
+            self._transaction.commit()
+            self._transaction = None
 
         if not self.as_sql:
             current_level = self.connection.get_isolation_level()
-            self.connection.execution_options(isolation_level="AUTOCOMMIT")
+            base_connection = self.connection
+
+            # in 1.3 and 1.4 non-future mode, the connection gets switched
+            # out.  we can use the base connection with the new mode
+            # except that it will not know it's in "autocommit" and will
+            # emit deprecation warnings when an autocommit action takes
+            # place.
+            self.connection = (
+                self.impl.connection
+            ) = base_connection.execution_options(isolation_level="AUTOCOMMIT")
         try:
             yield
         finally:
@@ -288,13 +303,13 @@ class MigrationContext(object):
                 self.connection.execution_options(
                     isolation_level=current_level
                 )
+                self.connection = self.impl.connection = base_connection
 
-            if self.impl.transactional_ddl:
-                if self.as_sql:
-                    self.impl.emit_begin()
+            if self.impl.transactional_ddl and self.as_sql:
+                self.impl.emit_begin()
 
-                elif _in_connection_transaction:
-                    self._transaction = self.bind.begin()
+            elif _in_connection_transaction:
+                self._transaction = self.connection.begin()
 
     def begin_transaction(self, _per_migration=False):
         """Begin a logical transaction for migration operations.
@@ -337,23 +352,50 @@ class MigrationContext(object):
             :meth:`.MigrationContext.autocommit_block`
 
         """
-        transaction_now = _per_migration == self._transaction_per_migration
 
-        if not transaction_now:
+        @contextmanager
+        def do_nothing():
+            yield
 
-            @contextmanager
-            def do_nothing():
-                yield
+        if self._in_external_transaction:
+            return do_nothing()
 
+        if self.impl.transactional_ddl:
+            transaction_now = _per_migration == self._transaction_per_migration
+        else:
+            transaction_now = _per_migration is True
+
+        if not transaction_now:
             return do_nothing()
 
         elif not self.impl.transactional_ddl:
+            assert _per_migration
 
-            @contextmanager
-            def do_nothing():
-                yield
-
-            return do_nothing()
+            if self.as_sql:
+                return do_nothing()
+            else:
+                # track our own notion of a "transaction block", which must be
+                # committed when complete.   Don't rely upon whether or not the
+                # SQLAlchemy connection reports as "in transaction"; this
+                # because SQLAlchemy future connection features autobegin
+                # behavior, so it may already be in a transaction from our
+                # emitting of queries like "has_version_table", etc. While we
+                # could track these operations as well, that leaves open the
+                # possibility of new operations or other things happening in
+                # the user environment that still may be triggering
+                # "autobegin".
+
+                in_transaction = self._transaction is not None
+
+                if in_transaction:
+                    return do_nothing()
+                else:
+                    self._transaction = (
+                        sqla_compat._safe_begin_connection_transaction(
+                            self.connection
+                        )
+                    )
+                    return _ProxyTransaction(self)
         elif self.as_sql:
 
             @contextmanager
@@ -364,7 +406,9 @@ class MigrationContext(object):
 
             return begin_commit()
         else:
-            self._transaction = self.bind.begin()
+            self._transaction = sqla_compat._safe_begin_connection_transaction(
+                self.connection
+            )
             return _ProxyTransaction(self)
 
     def get_current_revision(self):
@@ -439,9 +483,10 @@ class MigrationContext(object):
         )
 
     def _ensure_version_table(self, purge=False):
-        self._version.create(self.connection, checkfirst=True)
-        if purge:
-            self.connection.execute(self._version.delete())
+        with sqla_compat._ensure_scope_for_ddl(self.connection):
+            self._version.create(self.connection, checkfirst=True)
+            if purge:
+                self.connection.execute(self._version.delete())
 
     def _has_version_table(self):
         return sqla_compat._connectable_has_table(
@@ -504,12 +549,9 @@ class MigrationContext(object):
 
         head_maintainer = HeadMaintainer(self, heads)
 
-        starting_in_transaction = (
-            not self.as_sql and self._in_connection_transaction()
-        )
-
         for step in self._migrations_fn(heads, self):
             with self.begin_transaction(_per_migration=True):
+
                 if self.as_sql and not head_maintainer.heads:
                     # for offline mode, include a CREATE TABLE from
                     # the base
@@ -535,18 +577,6 @@ class MigrationContext(object):
                         run_args=kw,
                     )
 
-            if (
-                not starting_in_transaction
-                and not self.as_sql
-                and not self.impl.transactional_ddl
-                and self._in_connection_transaction()
-            ):
-                raise util.CommandError(
-                    'Migration "%s" has left an uncommitted '
-                    "transaction opened; transactional_ddl is False so "
-                    "Alembic is not committing transactions" % step
-                )
-
         if self.as_sql and not head_maintainer.heads:
             self._version.drop(self.connection)
 
index fea9e879f2ad459cd8ac8d9253108f10c8c27095..363895c1ed4bd7c113f4f6a33bd30ba119d6edd3 100644 (file)
@@ -171,7 +171,7 @@ class ScriptDirectory(object):
                     "ancestor/descendant revisions along the same branch"
                 )
             ancestor = ancestor % {"start": start, "end": end}
-            compat.raise_from_cause(util.CommandError(ancestor))
+            compat.raise_(util.CommandError(ancestor), from_=rna)
         except revision.MultipleHeads as mh:
             if not multiple_heads:
                 multiple_heads = (
@@ -185,15 +185,15 @@ class ScriptDirectory(object):
                 "head_arg": end or mh.argument,
                 "heads": util.format_as_comma(mh.heads),
             }
-            compat.raise_from_cause(util.CommandError(multiple_heads))
+            compat.raise_(util.CommandError(multiple_heads), from_=mh)
         except revision.ResolutionError as re:
             if resolution is None:
                 resolution = "Can't locate revision identified by '%s'" % (
                     re.argument
                 )
-            compat.raise_from_cause(util.CommandError(resolution))
+            compat.raise_(util.CommandError(resolution), from_=re)
         except revision.RevisionError as err:
-            compat.raise_from_cause(util.CommandError(err.args[0]))
+            compat.raise_(util.CommandError(err.args[0]), from_=err)
 
     def walk_revisions(self, base="base", head="heads"):
         """Iterate through all revisions.
@@ -571,7 +571,7 @@ class ScriptDirectory(object):
         try:
             Script.verify_rev_id(revid)
         except revision.RevisionError as err:
-            compat.raise_from_cause(util.CommandError(err.args[0]))
+            compat.raise_(util.CommandError(err.args[0]), from_=err)
 
         with self._catch_revision_errors(
             multiple_heads=(
@@ -659,7 +659,7 @@ class ScriptDirectory(object):
         try:
             script = Script._from_path(self, path)
         except revision.RevisionError as err:
-            compat.raise_from_cause(util.CommandError(err.args[0]))
+            compat.raise_(util.CommandError(err.args[0]), from_=err)
         if branch_labels and not script.branch_labels:
             raise util.CommandError(
                 "Version %s specified branch_labels %s, however the "
index 683d3227e8ee439f6989112a479955ea41dcd565..c75d1c0e643403a05d5e70d076a1a55fa96a174c 100644 (file)
@@ -422,9 +422,12 @@ class RevisionMap(object):
         except KeyError:
             try:
                 nonbranch_rev = self._revision_for_ident(branch_label)
-            except ResolutionError:
-                raise ResolutionError(
-                    "No such branch: '%s'" % branch_label, branch_label
+            except ResolutionError as re:
+                util.raise_(
+                    ResolutionError(
+                        "No such branch: '%s'" % branch_label, branch_label
+                    ),
+                    from_=re,
                 )
             else:
                 return nonbranch_rev
index 7d0843b8c292824d7cc6b179395ed25ac0110f4b..d6d1d385e21cc073a0261c93098693fa37dce205 100644 (file)
@@ -39,9 +39,10 @@ def _invoke(name, revision, options):
     """
     try:
         hook = _registry[name]
-    except KeyError:
-        compat.raise_from_cause(
-            util.CommandError("No formatter with name '%s' registered" % name)
+    except KeyError as ke:
+        compat.raise_(
+            util.CommandError("No formatter with name '%s' registered" % name),
+            from_=ke,
         )
     else:
         return hook(revision, options)
@@ -65,12 +66,13 @@ def _run_hooks(path, hook_config):
         opts["_hook_name"] = name
         try:
             type_ = opts["type"]
-        except KeyError:
-            compat.raise_from_cause(
+        except KeyError as ke:
+            compat.raise_(
                 util.CommandError(
                     "Key %s.type is required for post write hook %r"
                     % (name, name)
-                )
+                ),
+                from_=ke,
             )
         else:
             util.status(
@@ -89,12 +91,13 @@ def console_scripts(path, options):
 
     try:
         entrypoint_name = options["entrypoint"]
-    except KeyError:
-        compat.raise_from_cause(
+    except KeyError as ke:
+        compat.raise_(
             util.CommandError(
                 "Key %s.entrypoint is required for post write hook %r"
                 % (options["_hook_name"], options["_hook_name"])
-            )
+            ),
+            from_=ke,
         )
     iter_ = pkg_resources.iter_entry_points("console_scripts", entrypoint_name)
     impl = next(iter_)
index 23c0f19bec2a33266c8aec5f265173a89c670ea7..5f497a6341300766653cb065a8094acb83886947 100644 (file)
@@ -1,10 +1,12 @@
 from sqlalchemy.testing import config  # noqa
-from sqlalchemy.testing import exclusions  # noqa
 from sqlalchemy.testing import emits_warning  # noqa
 from sqlalchemy.testing import engines  # noqa
+from sqlalchemy.testing import exclusions  # noqa
 from sqlalchemy.testing import mock  # noqa
 from sqlalchemy.testing import provide_metadata  # noqa
 from sqlalchemy.testing import uses_deprecated  # noqa
+from sqlalchemy.testing.config import combinations  # noqa
+from sqlalchemy.testing.config import fixture  # noqa
 from sqlalchemy.testing.config import requirements as requires  # noqa
 
 from alembic import util  # noqa
@@ -13,12 +15,14 @@ from .assertions import assert_raises_message  # noqa
 from .assertions import emits_python_deprecation_warning  # noqa
 from .assertions import eq_  # noqa
 from .assertions import eq_ignore_whitespace  # noqa
+from .assertions import expect_raises  # noqa
+from .assertions import expect_raises_message  # noqa
+from .assertions import expect_sqlalchemy_deprecated  # noqa
+from .assertions import expect_sqlalchemy_deprecated_20  # noqa
 from .assertions import is_  # noqa
 from .assertions import is_false  # noqa
 from .assertions import is_not_  # noqa
 from .assertions import is_true  # noqa
 from .assertions import ne_  # noqa
-from .fixture_functions import combinations  # noqa
-from .fixture_functions import fixture  # noqa
 from .fixtures import TestBase  # noqa
 from .util import resolve_lambda  # noqa
index b09e09f3b071667c5c761ea9497d70f3edae9d0e..6d39f4c41ee37382775ae4391bc5242874949b16 100644 (file)
@@ -1,7 +1,10 @@
 from __future__ import absolute_import
 
+import contextlib
 import re
+import sys
 
+from sqlalchemy import exc as sa_exc
 from sqlalchemy import util
 from sqlalchemy.engine import default
 from sqlalchemy.testing.assertions import _expect_warnings
@@ -17,27 +20,92 @@ from ..util import sqla_compat
 from ..util.compat import py3k
 
 
+def _assert_proper_exception_context(exception):
+    """assert that any exception we're catching does not have a __context__
+    without a __cause__, and that __suppress_context__ is never set.
+
+    Python 3 will report nested as exceptions as "during the handling of
+    error X, error Y occurred". That's not what we want to do.  we want
+    these exceptions in a cause chain.
+
+    """
+
+    if not util.py3k:
+        return
+
+    if (
+        exception.__context__ is not exception.__cause__
+        and not exception.__suppress_context__
+    ):
+        assert False, (
+            "Exception %r was correctly raised but did not set a cause, "
+            "within context %r as its cause."
+            % (exception, exception.__context__)
+        )
+
+
 def assert_raises(except_cls, callable_, *args, **kw):
+    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+    return _assert_raises(except_cls, callable_, args, kw)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+    return _assert_raises(
+        except_cls, callable_, args, kwargs, msg=msg, check_context=True
+    )
+
+
+def assert_raises_message_context_ok(
+    except_cls, msg, callable_, *args, **kwargs
+):
+    return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+    except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
+
+    with _expect_raises(except_cls, msg, check_context) as ec:
+        callable_(*args, **kwargs)
+    return ec.error
+
+
+class _ErrorContainer(object):
+    error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+    ec = _ErrorContainer()
+    if check_context:
+        are_we_already_in_a_traceback = sys.exc_info()[0]
     try:
-        callable_(*args, **kw)
+        yield ec
         success = False
-    except except_cls:
+    except except_cls as err:
+        ec.error = err
         success = True
+        if msg is not None:
+            assert re.search(
+                msg, util.text_type(err), re.UNICODE
+            ), "%r !~ %s" % (msg, err)
+        if check_context and not are_we_already_in_a_traceback:
+            _assert_proper_exception_context(err)
+        print(util.text_type(err).encode("utf-8"))
 
     # assert outside the block so it works for AssertionError too !
     assert success, "Callable did not raise an exception"
 
 
-def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
-    try:
-        callable_(*args, **kwargs)
-        assert False, "Callable did not raise an exception"
-    except except_cls as e:
-        assert re.search(msg, util.text_type(e), re.UNICODE), "%r !~ %s" % (
-            msg,
-            e,
-        )
-        print(util.text_type(e).encode("utf-8"))
+def expect_raises(except_cls, check_context=True):
+    return _expect_raises(except_cls, check_context=check_context)
+
+
+def expect_raises_message(except_cls, msg, check_context=True):
+    return _expect_raises(except_cls, msg=msg, check_context=check_context)
 
 
 def eq_ignore_whitespace(a, b, msg=None):
@@ -106,3 +174,11 @@ def emits_python_deprecation_warning(*messages):
             return fn(*args, **kw)
 
     return decorate
+
+
+def expect_sqlalchemy_deprecated(*messages, **kw):
+    return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
+
+
+def expect_sqlalchemy_deprecated_20(*messages, **kw):
+    return _expect_warnings(sa_exc.RemovedIn20Warning, messages, **kw)
index 473c73e7a37d57e4062bbb17c9c34fa27004bac0..62b74ec577f031f8f7e39a96922d58c72d5c2ba7 100644 (file)
@@ -4,9 +4,10 @@ import os
 import shutil
 import textwrap
 
-from sqlalchemy.testing import engines
+from sqlalchemy.testing import config
 from sqlalchemy.testing import provision
 
+from . import util as testing_util
 from .. import util
 from ..script import Script
 from ..script import ScriptDirectory
@@ -93,25 +94,28 @@ config = context.config
         f.write(txt)
 
 
-def _sqlite_file_db(tempname="foo.db"):
+def _sqlite_file_db(tempname="foo.db", future=False):
     dir_ = os.path.join(_get_staging_directory(), "scripts")
     url = "sqlite:///%s/%s" % (dir_, tempname)
-    return engines.testing_engine(url=url)
+    return testing_util.testing_engine(url=url, future=future)
 
 
-def _sqlite_testing_config(sourceless=False):
+def _sqlite_testing_config(sourceless=False, future=False):
     dir_ = os.path.join(_get_staging_directory(), "scripts")
     url = "sqlite:///%s/foo.db" % dir_
 
+    sqlalchemy_future = future or ("future" in config.db.__class__.__module__)
+
     return _write_config_file(
         """
 [alembic]
 script_location = %s
 sqlalchemy.url = %s
 sourceless = %s
+%s
 
 [loggers]
-keys = root
+keys = root,sqlalchemy
 
 [handlers]
 keys = console
@@ -121,6 +125,11 @@ level = WARN
 handlers = console
 qualname =
 
+[logger_sqlalchemy]
+level = DEBUG
+handlers =
+qualname = sqlalchemy.engine
+
 [handler_console]
 class = StreamHandler
 args = (sys.stderr,)
@@ -134,12 +143,19 @@ keys = generic
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
     """
-        % (dir_, url, "true" if sourceless else "false")
+        % (
+            dir_,
+            url,
+            "true" if sourceless else "false",
+            "sqlalchemy.future = true" if sqlalchemy_future else "",
+        )
     )
 
 
 def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
     dir_ = os.path.join(_get_staging_directory(), "scripts")
+    sqlalchemy_future = "future" in config.db.__class__.__module__
+
     url = "sqlite:///%s/foo.db" % dir_
 
     return _write_config_file(
@@ -147,6 +163,7 @@ def _multi_dir_testing_config(sourceless=False, extra_version_location=""):
 [alembic]
 script_location = %s
 sqlalchemy.url = %s
+sqlalchemy.future = %s
 sourceless = %s
 version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s
 
@@ -177,6 +194,7 @@ datefmt = %%H:%%M:%%S
         % (
             dir_,
             url,
+            "true" if sqlalchemy_future else "false",
             "true" if sourceless else "false",
             extra_version_location,
         )
@@ -463,6 +481,8 @@ def _multidb_testing_config(engines):
 
     dir_ = os.path.join(_get_staging_directory(), "scripts")
 
+    sqlalchemy_future = "future" in config.db.__class__.__module__
+
     databases = ", ".join(engines.keys())
     engines = "\n\n".join(
         "[%s]\n" "sqlalchemy.url = %s" % (key, value.url)
@@ -474,7 +494,7 @@ def _multidb_testing_config(engines):
 [alembic]
 script_location = %s
 sourceless = false
-
+sqlalchemy.future = %s
 databases = %s
 
 %s
@@ -502,5 +522,5 @@ keys = generic
 format = %%(levelname)-5.5s [%%(name)s] %%(message)s
 datefmt = %%H:%%M:%%S
     """
-        % (dir_, databases, engines)
+        % (dir_, "true" if sqlalchemy_future else "false", databases, engines)
     )
diff --git a/alembic/testing/fixture_functions.py b/alembic/testing/fixture_functions.py
deleted file mode 100644 (file)
index 2640693..0000000
+++ /dev/null
@@ -1,79 +0,0 @@
-_fixture_functions = None  # installed by plugin_base
-
-
-def combinations(*comb, **kw):
-    r"""Deliver multiple versions of a test based on positional combinations.
-
-    This is a facade over pytest.mark.parametrize.
-
-
-    :param \*comb: argument combinations.  These are tuples that will be passed
-     positionally to the decorated function.
-
-    :param argnames: optional list of argument names.   These are the names
-     of the arguments in the test function that correspond to the entries
-     in each argument tuple.   pytest.mark.parametrize requires this, however
-     the combinations function will derive it automatically if not present
-     by using ``inspect.getfullargspec(fn).args[1:]``.  Note this assumes the
-     first argument is "self" which is discarded.
-
-    :param id\_: optional id template.  This is a string template that
-     describes how the "id" for each parameter set should be defined, if any.
-     The number of characters in the template should match the number of
-     entries in each argument tuple.   Each character describes how the
-     corresponding entry in the argument tuple should be handled, as far as
-     whether or not it is included in the arguments passed to the function, as
-     well as if it is included in the tokens used to create the id of the
-     parameter set.
-
-     If omitted, the argument combinations are passed to parametrize as is.  If
-     passed, each argument combination is turned into a pytest.param() object,
-     mapping the elements of the argument tuple to produce an id based on a
-     character value in the same position within the string template using the
-     following scheme::
-
-        i - the given argument is a string that is part of the id only, don't
-            pass it as an argument
-
-        n - the given argument should be passed and it should be added to the
-            id by calling the .__name__ attribute
-
-        r - the given argument should be passed and it should be added to the
-            id by calling repr()
-
-        s - the given argument should be passed and it should be added to the
-            id by calling str()
-
-        a - (argument) the given argument should be passed and it should not
-            be used to generated the id
-
-     e.g.::
-
-        @testing.combinations(
-            (operator.eq, "eq"),
-            (operator.ne, "ne"),
-            (operator.gt, "gt"),
-            (operator.lt, "lt"),
-            id_="na"
-        )
-        def test_operator(self, opfunc, name):
-            pass
-
-    The above combination will call ``.__name__`` on the first member of
-    each tuple and use that as the "id" to pytest.param().
-
-
-    """
-    return _fixture_functions.combinations(*comb, **kw)
-
-
-def fixture(*arg, **kw):
-    return _fixture_functions.fixture(*arg, **kw)
-
-
-def get_current_test_name():
-    return _fixture_functions.get_current_test_name()
-
-
-def skip_test(msg):
-    raise _fixture_functions.skip_test_exception(msg)
index dd1d2f16409c5129df9cd2103c9b2e77fa165973..d5d45ac73bfb4d183027a804487bab06af6235f5 100644 (file)
@@ -8,11 +8,13 @@ from sqlalchemy import inspect
 from sqlalchemy import MetaData
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy.testing import config
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertions import eq_
-from sqlalchemy.testing.fixtures import TestBase  # noqa
+from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest
+from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase
 
 import alembic
 from .assertions import _get_dialect
@@ -20,16 +22,53 @@ from ..environment import EnvironmentContext
 from ..migration import MigrationContext
 from ..operations import Operations
 from ..util import compat
+from ..util import sqla_compat
 from ..util.compat import configparser
 from ..util.compat import string_types
 from ..util.compat import text_type
 from ..util.sqla_compat import create_mock_engine
 from ..util.sqla_compat import sqla_14
 
+
 testing_config = configparser.ConfigParser()
 testing_config.read(["test.cfg"])
 
 
+class TestBase(SQLAlchemyTestBase):
+    is_sqlalchemy_future = False
+
+    @testing.fixture()
+    def ops_context(self, migration_context):
+        with migration_context.begin_transaction(_per_migration=True):
+            yield Operations(migration_context)
+
+    @testing.fixture
+    def migration_context(self, connection):
+        return MigrationContext.configure(
+            connection, opts=dict(transaction_per_migration=True)
+        )
+
+    @testing.fixture
+    def connection(self):
+        with config.db.connect() as conn:
+            yield conn
+
+
+class TablesTest(TestBase, SQLAlchemyTablesTest):
+    pass
+
+
+if sqla_14:
+    from sqlalchemy.testing.fixtures import FutureEngineMixin
+else:
+
+    class FutureEngineMixin(object):
+        __requires__ = ("sqlalchemy_14",)
+
+
+FutureEngineMixin.is_sqlalchemy_future = True
+
+
 def capture_db(dialect="postgresql://"):
     buf = []
 
@@ -205,7 +244,8 @@ class AlterColRoundTripFixture(object):
         ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
 
     def tearDown(self):
-        self.metadata.drop_all(self.conn)
+        with self.conn.begin():
+            self.metadata.drop_all(self.conn)
         self.conn.close()
 
     def _run_alter_col(self, from_, to_, compare=None):
@@ -218,26 +258,27 @@ class AlterColRoundTripFixture(object):
         )
         t = Table("x", self.metadata, column)
 
-        t.create(self.conn)
-        insp = inspect(self.conn)
-        old_col = insp.get_columns("x")[0]
-
-        # TODO: conditional comment support
-        self.op.alter_column(
-            "x",
-            column.name,
-            existing_type=column.type,
-            existing_server_default=column.server_default
-            if column.server_default is not None
-            else False,
-            existing_nullable=True if column.nullable else False,
-            # existing_comment=column.comment,
-            nullable=to_.get("nullable", None),
-            # modify_comment=False,
-            server_default=to_.get("server_default", False),
-            new_column_name=to_.get("name", None),
-            type_=to_.get("type", None),
-        )
+        with sqla_compat._ensure_scope_for_ddl(self.conn):
+            t.create(self.conn)
+            insp = inspect(self.conn)
+            old_col = insp.get_columns("x")[0]
+
+            # TODO: conditional comment support
+            self.op.alter_column(
+                "x",
+                column.name,
+                existing_type=column.type,
+                existing_server_default=column.server_default
+                if column.server_default is not None
+                else False,
+                existing_nullable=True if column.nullable else False,
+                # existing_comment=column.comment,
+                nullable=to_.get("nullable", None),
+                # modify_comment=False,
+                server_default=to_.get("server_default", False),
+                new_column_name=to_.get("name", None),
+                type_=to_.get("type", None),
+            )
 
         insp = inspect(self.conn)
         new_col = insp.get_columns("x")[0]
index 8200ec18bfb930d135c52aaafdf8b4a63a830cb8..d4a2c5521847f6d34003b49b2826ae49b1d84c29 100644 (file)
@@ -1,35 +1,4 @@
 """
 Bootstrapper for test framework plugins.
 
-This is vendored from SQLAlchemy so that we can use local overrides
-for plugin_base.py and pytestplugin.py.
-
 """
-
-
-import os
-import sys
-
-
-bootstrap_file = locals()["bootstrap_file"]
-to_bootstrap = locals()["to_bootstrap"]
-
-
-def load_file_as_module(name):
-    path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)
-    if sys.version_info >= (3, 3):
-        from importlib import machinery
-
-        mod = machinery.SourceFileLoader(name, path).load_module()
-    else:
-        import imp
-
-        mod = imp.load_source(name, path)
-    return mod
-
-
-if to_bootstrap == "pytest":
-    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
-    sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
-else:
-    raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
diff --git a/alembic/testing/plugin/plugin_base.py b/alembic/testing/plugin/plugin_base.py
deleted file mode 100644 (file)
index 2d5e95a..0000000
+++ /dev/null
@@ -1,125 +0,0 @@
-"""vendored plugin_base functions from the most recent SQLAlchemy versions.
-
-Alembic tests need to run on older versions of SQLAlchemy that don't
-necessarily have all the latest testing fixtures.
-
-"""
-from __future__ import absolute_import
-
-import abc
-import sys
-
-from sqlalchemy.testing.plugin.plugin_base import *  # noqa
-from sqlalchemy.testing.plugin.plugin_base import post
-from sqlalchemy.testing.plugin.plugin_base import post_begin as sqla_post_begin
-from sqlalchemy.testing.plugin.plugin_base import stop_test_class as sqla_stc
-
-py3k = sys.version_info >= (3, 0)
-
-
-if py3k:
-
-    ABC = abc.ABC
-else:
-
-    class ABC(object):
-        __metaclass__ = abc.ABCMeta
-
-
-def post_begin():
-    sqla_post_begin()
-
-    import warnings
-
-    try:
-        import pytest
-    except ImportError:
-        pass
-    else:
-        warnings.filterwarnings(
-            "once", category=pytest.PytestDeprecationWarning
-        )
-
-    from sqlalchemy import exc
-
-    if hasattr(exc, "RemovedIn20Warning"):
-        warnings.filterwarnings(
-            "error",
-            category=exc.RemovedIn20Warning,
-            message=".*Engine.execute",
-        )
-        warnings.filterwarnings(
-            "error",
-            category=exc.RemovedIn20Warning,
-            message=".*Passing a string",
-        )
-
-
-# override selected SQLAlchemy pytest hooks with vendored functionality
-def stop_test_class(cls):
-    sqla_stc(cls)
-    import os
-    from alembic.testing.env import _get_staging_directory
-
-    assert not os.path.exists(_get_staging_directory()), (
-        "staging directory %s was not cleaned up" % _get_staging_directory()
-    )
-
-
-def want_class(name, cls):
-    from sqlalchemy.testing import config
-    from sqlalchemy.testing import fixtures
-
-    if not issubclass(cls, fixtures.TestBase):
-        return False
-    elif name.startswith("_"):
-        return False
-    elif (
-        config.options.backend_only
-        and not getattr(cls, "__backend__", False)
-        and not getattr(cls, "__sparse_backend__", False)
-    ):
-        return False
-    else:
-        return True
-
-
-@post
-def _init_symbols(options, file_config):
-    from sqlalchemy.testing import config
-    from alembic.testing import fixture_functions as alembic_config
-
-    config._fixture_functions = (
-        alembic_config._fixture_functions
-    ) = _fixture_fn_class()
-
-
-class FixtureFunctions(ABC):
-    @abc.abstractmethod
-    def skip_test_exception(self, *arg, **kw):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def combinations(self, *args, **kw):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def param_ident(self, *args, **kw):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def fixture(self, *arg, **kw):
-        raise NotImplementedError()
-
-    def get_current_test_name(self):
-        raise NotImplementedError()
-
-
-_fixture_fn_class = None
-
-
-def set_fixture_functions(fixture_fn_class):
-    from sqlalchemy.testing.plugin import plugin_base
-
-    global _fixture_fn_class
-    _fixture_fn_class = plugin_base._fixture_fn_class = fixture_fn_class
diff --git a/alembic/testing/plugin/pytestplugin.py b/alembic/testing/plugin/pytestplugin.py
deleted file mode 100644 (file)
index 6b76a17..0000000
+++ /dev/null
@@ -1,314 +0,0 @@
-"""vendored pytestplugin functions from the most recent SQLAlchemy versions.
-
-Alembic tests need to run on older versions of SQLAlchemy that don't
-necessarily have all the latest testing fixtures.
-
-"""
-try:
-    # installed by bootstrap.py
-    import sqla_plugin_base as plugin_base
-except ImportError:
-    # assume we're a package, use traditional import
-    from . import plugin_base
-
-from functools import update_wrapper
-import inspect
-import itertools
-import operator
-import os
-import re
-import sys
-
-import pytest
-from sqlalchemy.testing.plugin.pytestplugin import *  # noqa
-from sqlalchemy.testing.plugin.pytestplugin import pytest_configure as spc
-
-py3k = sys.version_info.major >= 3
-
-if py3k:
-    from typing import TYPE_CHECKING
-else:
-    TYPE_CHECKING = False
-
-if TYPE_CHECKING:
-    from typing import Sequence
-
-
-# override selected SQLAlchemy pytest hooks with vendored functionality
-def pytest_configure(config):
-    spc(config)
-
-    plugin_base.set_fixture_functions(PytestFixtureFunctions)
-
-
-def pytest_pycollect_makeitem(collector, name, obj):
-
-    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
-        ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-
-        return [
-            ctor(name=parametrize_cls.__name__, parent=collector)
-            for parametrize_cls in _parametrize_cls(collector.module, obj)
-        ]
-    elif (
-        inspect.isfunction(obj)
-        and isinstance(collector, pytest.Instance)
-        and plugin_base.want_method(collector.cls, obj)
-    ):
-        # None means, fall back to default logic, which includes
-        # method-level parametrize
-        return None
-    else:
-        # empty list means skip this item
-        return []
-
-
-_current_class = None
-
-
-def _parametrize_cls(module, cls):
-    """implement a class-based version of pytest parametrize."""
-
-    if "_sa_parametrize" not in cls.__dict__:
-        return [cls]
-
-    _sa_parametrize = cls._sa_parametrize
-    classes = []
-    for full_param_set in itertools.product(
-        *[params for argname, params in _sa_parametrize]
-    ):
-        cls_variables = {}
-
-        for argname, param in zip(
-            [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
-        ):
-            if not argname:
-                raise TypeError("need argnames for class-based combinations")
-            argname_split = re.split(r",\s*", argname)
-            for arg, val in zip(argname_split, param.values):
-                cls_variables[arg] = val
-        parametrized_name = "_".join(
-            # token is a string, but in py2k py.test is giving us a unicode,
-            # so call str() on it.
-            str(re.sub(r"\W", "", token))
-            for param in full_param_set
-            for token in param.id.split("-")
-        )
-        name = "%s_%s" % (cls.__name__, parametrized_name)
-        newcls = type.__new__(type, name, (cls,), cls_variables)
-        setattr(module, name, newcls)
-        classes.append(newcls)
-    return classes
-
-
-def getargspec(fn):
-    if sys.version_info.major == 3:
-        return inspect.getfullargspec(fn)
-    else:
-        return inspect.getargspec(fn)
-
-
-def _pytest_fn_decorator(target):
-    """Port of langhelpers.decorator with pytest-specific tricks."""
-    # from sqlalchemy rel_1_3_14
-
-    from sqlalchemy.util.langhelpers import format_argspec_plus
-    from sqlalchemy.util.compat import inspect_getfullargspec
-
-    def _exec_code_in_env(code, env, fn_name):
-        exec(code, env)
-        return env[fn_name]
-
-    def decorate(fn, add_positional_parameters=()):
-
-        spec = inspect_getfullargspec(fn)
-        if add_positional_parameters:
-            spec.args.extend(add_positional_parameters)
-
-        metadata = dict(target="target", fn="__fn", name=fn.__name__)
-        metadata.update(format_argspec_plus(spec, grouped=False))
-        code = (
-            """\
-def %(name)s(%(args)s):
-    return %(target)s(%(fn)s, %(apply_kw)s)
-"""
-            % metadata
-        )
-        decorated = _exec_code_in_env(
-            code, {"target": target, "__fn": fn}, fn.__name__
-        )
-        if not add_positional_parameters:
-            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
-            decorated.__wrapped__ = fn
-            return update_wrapper(decorated, fn)
-        else:
-            # this is the pytest hacky part.  don't do a full update wrapper
-            # because pytest is really being sneaky about finding the args
-            # for the wrapped function
-            decorated.__module__ = fn.__module__
-            decorated.__name__ = fn.__name__
-            return decorated
-
-    return decorate
-
-
-class PytestFixtureFunctions(plugin_base.FixtureFunctions):
-    def skip_test_exception(self, *arg, **kw):
-        return pytest.skip.Exception(*arg, **kw)
-
-    _combination_id_fns = {
-        "i": lambda obj: obj,
-        "r": repr,
-        "s": str,
-        "n": operator.attrgetter("__name__"),
-    }
-
-    def combinations(self, *arg_sets, **kw):
-        """Facade for pytest.mark.parametrize.
-
-        Automatically derives argument names from the callable which in our
-        case is always a method on a class with positional arguments.
-
-        ids for parameter sets are derived using an optional template.
-
-        """
-        # from sqlalchemy rel_1_3_14
-        from alembic.testing import exclusions
-
-        if sys.version_info.major == 3:
-            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
-                arg_sets = list(arg_sets[0])
-        else:
-            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
-                arg_sets = list(arg_sets[0])
-
-        argnames = kw.pop("argnames", None)
-
-        def _filter_exclusions(args):
-            result = []
-            gathered_exclusions = []
-            for a in args:
-                if isinstance(a, exclusions.compound):
-                    gathered_exclusions.append(a)
-                else:
-                    result.append(a)
-
-            return result, gathered_exclusions
-
-        id_ = kw.pop("id_", None)
-
-        tobuild_pytest_params = []
-        has_exclusions = False
-        if id_:
-            _combination_id_fns = self._combination_id_fns
-
-            # because itemgetter is not consistent for one argument vs.
-            # multiple, make it multiple in all cases and use a slice
-            # to omit the first argument
-            _arg_getter = operator.itemgetter(
-                0,
-                *[
-                    idx
-                    for idx, char in enumerate(id_)
-                    if char in ("n", "r", "s", "a")
-                ]
-            )
-            fns = [
-                (operator.itemgetter(idx), _combination_id_fns[char])
-                for idx, char in enumerate(id_)
-                if char in _combination_id_fns
-            ]
-
-            for arg in arg_sets:
-                if not isinstance(arg, tuple):
-                    arg = (arg,)
-
-                fn_params, param_exclusions = _filter_exclusions(arg)
-
-                parameters = _arg_getter(fn_params)[1:]
-
-                if param_exclusions:
-                    has_exclusions = True
-
-                tobuild_pytest_params.append(
-                    (
-                        parameters,
-                        param_exclusions,
-                        "-".join(
-                            comb_fn(getter(arg)) for getter, comb_fn in fns
-                        ),
-                    )
-                )
-
-        else:
-
-            for arg in arg_sets:
-                if not isinstance(arg, tuple):
-                    arg = (arg,)
-
-                fn_params, param_exclusions = _filter_exclusions(arg)
-
-                if param_exclusions:
-                    has_exclusions = True
-
-                tobuild_pytest_params.append(
-                    (fn_params, param_exclusions, None)
-                )
-
-        pytest_params = []
-        for parameters, param_exclusions, id_ in tobuild_pytest_params:
-            if has_exclusions:
-                parameters += (param_exclusions,)
-
-            param = pytest.param(*parameters, id=id_)
-            pytest_params.append(param)
-
-        def decorate(fn):
-            if inspect.isclass(fn):
-                if has_exclusions:
-                    raise NotImplementedError(
-                        "exclusions not supported for class level combinations"
-                    )
-                if "_sa_parametrize" not in fn.__dict__:
-                    fn._sa_parametrize = []
-                fn._sa_parametrize.append((argnames, pytest_params))
-                return fn
-            else:
-                if argnames is None:
-                    _argnames = getargspec(fn).args[1:]  # type: Sequence(str)
-                else:
-                    _argnames = re.split(
-                        r", *", argnames
-                    )  # type: Sequence(str)
-
-                if has_exclusions:
-                    _argnames += ["_exclusions"]
-
-                    @_pytest_fn_decorator
-                    def check_exclusions(fn, *args, **kw):
-                        _exclusions = args[-1]
-                        if _exclusions:
-                            exlu = exclusions.compound().add(*_exclusions)
-                            fn = exlu(fn)
-                        return fn(*args[0:-1], **kw)
-
-                    def process_metadata(spec):
-                        spec.args.append("_exclusions")
-
-                    fn = check_exclusions(
-                        fn, add_positional_parameters=("_exclusions",)
-                    )
-
-                return pytest.mark.parametrize(_argnames, pytest_params)(fn)
-
-        return decorate
-
-    def param_ident(self, *parameters):
-        ident = parameters[0]
-        return pytest.param(*parameters[1:], id=ident)
-
-    def fixture(self, *arg, **kw):
-        return pytest.fixture(*arg, **kw)
-
-    def get_current_test_name(self):
-        return os.environ.get("PYTEST_CURRENT_TEST")
index 456002d825c01332577e1ea1720bde95c6c47eb3..2de8a71793495772ff8565beacaa6d70fa77455b 100644 (file)
@@ -43,6 +43,15 @@ class SuiteRequirements(Requirements):
 
         return exclusions.skip_if(doesnt_have_check_uq_constraints)
 
+    @property
+    def sequences(self):
+        """Target database must support SEQUENCEs."""
+
+        return exclusions.only_if(
+            [lambda config: config.db.dialect.supports_sequences],
+            "no sequence support",
+        )
+
     @property
     def foreign_key_match(self):
         return exclusions.open()
index 3e766456cf2889452c20cbcdd2936addda62a83e..ccabf9cdc705a1bce6db85143ee76e37de4143db 100644 (file)
@@ -95,3 +95,13 @@ def metadata_fixture(ddl="function"):
         return fixture_functions.fixture(scope=ddl)(run_ddl)
 
     return decorate
+
+
+def testing_engine(url=None, options=None, future=False):
+    from sqlalchemy.testing import config
+    from sqlalchemy.testing.engines import testing_engine
+
+    if not future:
+        future = getattr(config._current.options, "future_engine", False)
+    kw = {"future": future} if future else {}
+    return testing_engine(url, options, **kw)
diff --git a/alembic/testing/warnings.py b/alembic/testing/warnings.py
new file mode 100644 (file)
index 0000000..0182032
--- /dev/null
@@ -0,0 +1,48 @@
+# testing/warnings.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import warnings
+
+from sqlalchemy import exc as sa_exc
+
+
+def setup_filters():
+    """Set global warning behavior for the test suite."""
+
+    warnings.resetwarnings()
+
+    warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
+    warnings.filterwarnings("error", category=sa_exc.SAWarning)
+
+    # some selected deprecations...
+    warnings.filterwarnings("error", category=DeprecationWarning)
+    try:
+        import pytest
+    except ImportError:
+        pass
+    else:
+        warnings.filterwarnings(
+            "once", category=pytest.PytestDeprecationWarning
+        )
+
+    if hasattr(sa_exc, "RemovedIn20Warning"):
+        for msg in [
+            #
+            # Core execution - need to remove this after SQLAlchemy
+            # repairs it in provisioning
+            #
+            r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept "
+            "parameters as a single dictionary or a single sequence of "
+            "dictionaries only.",
+        ]:
+            warnings.filterwarnings(
+                "ignore",
+                message=msg,
+                category=sa_exc.RemovedIn20Warning,
+            )
index 141ba4502528f57aafb8fb868d755d530df079dc..cfdab4984b06ce179a4d8ff43f1ae014a7f32566 100644 (file)
@@ -1,4 +1,4 @@
-from .compat import raise_from_cause  # noqa
+from .compat import raise_  # noqa
 from .exc import CommandError
 from .langhelpers import _with_legacy_names  # noqa
 from .langhelpers import asbool  # noqa
index f5a04ef05deace5818e2df32d68a3a572b9d0afb..c8919b62ece82d3dda37daf3d0f9c5418784c491 100644 (file)
@@ -261,34 +261,54 @@ def with_metaclass(meta, base=object):
 
 if py3k:
 
-    def reraise(tp, value, tb=None, cause=None):
-        if cause is not None:
-            value.__cause__ = cause
-        if value.__traceback__ is not tb:
-            raise value.with_traceback(tb)
-        raise value
+    def raise_(
+        exception, with_traceback=None, replace_context=None, from_=False
+    ):
+        r"""implement "raise" with cause support.
+
+        :param exception: exception to raise
+        :param with_traceback: will call exception.with_traceback()
+        :param replace_context: an as-yet-unsupported feature.  This is
+         an exception object which we are "replacing", e.g., it's our
+         "cause" but we don't want it printed.    Basically just what
+         ``__suppress_context__`` does but we don't want to suppress
+         the enclosing context, if any.  So for now we make it the
+         cause.
+        :param from\_: the cause.  this actually sets the cause and doesn't
+         hope to hide it someday.
 
-    def raise_from_cause(exception, exc_info=None):
-        if exc_info is None:
-            exc_info = sys.exc_info()
-        exc_type, exc_value, exc_tb = exc_info
-        reraise(type(exception), exception, tb=exc_tb, cause=exc_value)
+        """
+        if with_traceback is not None:
+            exception = exception.with_traceback(with_traceback)
+
+        if from_ is not False:
+            exception.__cause__ = from_
+        elif replace_context is not None:
+            # no good solution here, we would like to have the exception
+            # have only the context of replace_context.__context__ so that the
+            # intermediary exception does not change, but we can't figure
+            # that out.
+            exception.__cause__ = replace_context
+
+        try:
+            raise exception
+        finally:
+            # credit to
+            # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
+            # as the __traceback__ object creates a cycle
+            del exception, replace_context, from_, with_traceback
 
 
 else:
     exec(
-        "def reraise(tp, value, tb=None, cause=None):\n"
-        "    raise tp, value, tb\n"
+        "def raise_(exception, with_traceback=None, replace_context=None, "
+        "from_=False):\n"
+        "    if with_traceback:\n"
+        "        raise type(exception), exception, with_traceback\n"
+        "    else:\n"
+        "        raise exception\n"
     )
 
-    def raise_from_cause(exception, exc_info=None):
-        # not as nice as that of Py3K, but at least preserves
-        # the code line where the issue occurred
-        if exc_info is None:
-            exc_info = sys.exc_info()
-        exc_type, exc_value, exc_tb = exc_info
-        reraise(type(exception), exception, tb=exc_tb)
-
 
 # produce a wrapper that allows encoded text to stream
 # into a given buffer, but doesn't close it.
index bb9c8f534e16e8ece150f8def20a1e7d507e5e0e..cc07f4b4ac10f082d9bbc43b157467161adf5e6b 100644 (file)
@@ -7,6 +7,7 @@ from .compat import callable
 from .compat import collections_abc
 from .compat import exec_
 from .compat import inspect_getargspec
+from .compat import raise_
 from .compat import string_types
 from .compat import with_metaclass
 
@@ -74,13 +75,16 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
     def _create_method_proxy(cls, name, globals_, locals_):
         fn = getattr(cls, name)
 
-        def _name_error(name):
-            raise NameError(
-                "Can't invoke function '%s', as the proxy object has "
-                "not yet been "
-                "established for the Alembic '%s' class.  "
-                "Try placing this code inside a callable."
-                % (name, cls.__name__)
+        def _name_error(name, from_):
+            raise_(
+                NameError(
+                    "Can't invoke function '%s', as the proxy object has "
+                    "not yet been "
+                    "established for the Alembic '%s' class.  "
+                    "Try placing this code inside a callable."
+                    % (name, cls.__name__)
+                ),
+                from_=from_,
             )
 
         globals_["_name_error"] = _name_error
@@ -142,8 +146,8 @@ class ModuleClsProxy(with_metaclass(_ModuleClsMeta)):
             %(translate)s
             try:
                 p = _proxy
-            except NameError:
-                _name_error('%(name)s')
+            except NameError as ne:
+                _name_error('%(name)s', ne)
             return _proxy.%(name)s(%(apply_kw)s)
             e
         """
index b65df2c4886df8cbea75e9b925da5398ab329d05..8dfb9db0b3f40677ce1e819df2f48c187e0ff63f 100644 (file)
@@ -10,6 +10,7 @@ from .compat import has_pep3147
 from .compat import load_module_py
 from .compat import load_module_pyc
 from .compat import py3k
+from .compat import raise_
 from .exc import CommandError
 
 
@@ -82,7 +83,7 @@ def edit(path):
     try:
         editor.edit(path)
     except Exception as exc:
-        raise CommandError("Error executing editor (%s)" % (exc,))
+        raise_(CommandError("Error executing editor (%s)" % (exc,)), from_=exc)
 
 
 def load_python_file(dir_, filename):
index 159d0f09b10b56aeff9309441fb99284626187a2..29c2519d000e2890bb431a4d7e59eaf422ee14f3 100644 (file)
@@ -1,3 +1,4 @@
+import contextlib
 import re
 
 from sqlalchemy import __version__
@@ -64,6 +65,46 @@ else:
 AUTOINCREMENT_DEFAULT = "auto"
 
 
+@contextlib.contextmanager
+def _ensure_scope_for_ddl(connection):
+    try:
+        in_transaction = connection.in_transaction
+    except AttributeError:
+        # catch for MockConnection
+        yield
+    else:
+        if not in_transaction():
+            with connection.begin():
+                yield
+        else:
+            yield
+
+
+def _safe_begin_connection_transaction(connection):
+    transaction = _get_connection_transaction(connection)
+    if transaction:
+        return transaction
+    else:
+        return connection.begin()
+
+
+def _get_connection_in_transaction(connection):
+    try:
+        in_transaction = connection.in_transaction
+    except AttributeError:
+        # catch for MockConnection
+        return False
+    else:
+        return in_transaction()
+
+
+def _get_connection_transaction(connection):
+    if sqla_14:
+        return connection.get_transaction()
+    else:
+        return connection._Connection__transaction
+
+
 def _create_url(*arg, **kw):
     if hasattr(url.URL, "create"):
         return url.URL.create(*arg, **kw)
@@ -314,8 +355,16 @@ def _mariadb_normalized_version_info(mysql_dialect):
     return mysql_dialect._mariadb_normalized_version_info
 
 
+def _insert_inline(table):
+    if sqla_14:
+        return table.insert().inline()
+    else:
+        return table.insert(inline=True)
+
+
 if sqla_14:
     from sqlalchemy import create_mock_engine
+    from sqlalchemy import select as _select
 else:
     from sqlalchemy import create_engine
 
@@ -323,3 +372,6 @@ else:
         return create_engine(
             "postgresql://", strategy="mock", executor=executor
         )
+
+    def _select(*columns):
+        return sql.select(list(columns))
index 7072aa3c7c7afc8ba80d053e8181b759827d76b7..46dde52302bfa49cebdc91448be3404c6e077791 100644 (file)
@@ -465,7 +465,7 @@ are being used::
 
 Above, ``inspected_column`` is a :class:`sqlalchemy.schema.Column` as
 returned by
-:meth:`sqlalchemy.engine.reflection.Inspector.reflecttable`, whereas
+:meth:`sqlalchemy.engine.reflection.Inspector.reflect_table`, whereas
 ``metadata_column`` is a :class:`sqlalchemy.schema.Column` from the
 local model environment.  A return value of ``None`` indicates that default
 type comparison to proceed.
diff --git a/docs/build/unreleased/autocommit.rst b/docs/build/unreleased/autocommit.rst
new file mode 100644 (file)
index 0000000..39d6098
--- /dev/null
@@ -0,0 +1,21 @@
+.. change::
+       :tags: change, environment
+
+       To accommodate SQLAlchemy 1.4 and 2.0, the migration model now no longer
+       assumes that the SQLAlchemy Connection will autocommit an individual
+       operation.   This essentially means that for databases that use
+       non-transactional DDL (pysqlite current driver behavior, MySQL), there is
+       still a BEGIN/COMMIT block that will surround each individual migration.
+       Databases that support transactional DDL should continue to have the
+       same flow, either per migration or per-entire run, depending on the
+       value of the :paramref:`.Environment.configure.transaction_per_migration`
+       flag.
+
+
+.. change::
+       :tags: change, environment
+
+       It now raises a :class:`.CommandError` if a ``sqlalchemy.engine.Engine``
+       is passed to the :meth:`.MigrationContext.configure` method instead of
+       a ``sqlalchemy.engine.Connection`` object.  Previously, this would
+       be a warning only.
\ No newline at end of file
index e41b58f16b6207e1b238f62731bbf8ef841bef81..054c9fc5d008a7dc6bc9cfc2edbf882dc996a647 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,66 @@
+[metadata]
+
+name = alembic
+
+# version comes from setup.py; setuptools
+# can't read the "attr:" here without importing
+# until version 47.0.0 which is too recent
+
+
+description = A database migration tool for SQLAlchemy.
+long_description = file: README.rst
+long_description_content_type = text/x-rst
+url=https://alembic.sqlalchemy.org
+author = Mike Bayer
+author_email = mike_mp@zzzcomputing.com
+license = MIT
+license_file = LICENSE
+
+
+classifiers =
+    Development Status :: 5 - Production/Stable
+    Intended Audience :: Developers
+    Environment :: Console
+    License :: OSI Approved :: MIT License
+    Operating System :: OS Independent
+    Programming Language :: Python
+    Programming Language :: Python :: 2
+    Programming Language :: Python :: 2.7
+    Programming Language :: Python :: 3
+    Programming Language :: Python :: 3.6
+    Programming Language :: Python :: 3.7
+    Programming Language :: Python :: 3.8
+    Programming Language :: Python :: 3.9
+    Programming Language :: Python :: Implementation :: CPython
+    Programming Language :: Python :: Implementation :: PyPy
+    Topic :: Database :: Front-Ends
+
+[options]
+packages = find:
+include_package_data = true
+zip_safe = false
+python_requires = >=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*
+package_dir =
+    =.
+
+install_requires =
+    SQLAlchemy>=1.3.0
+    Mako
+    python-editor>=0.3
+    python-dateutil
+
+[options.packages.find]
+exclude =
+    test*
+    examples*
+
+[options.exclude_package_data]
+'' = test*
+
+[options.entry_points]
+console_scripts =
+    alembic = alembic.config:main
+
 [egg_info]
 tag_build=dev
 
@@ -40,8 +103,9 @@ default=sqlite:///:memory:
 sqlite=sqlite:///:memory:
 sqlite_file=sqlite:///querytest.db
 postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
-mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8
-mssql=mssql+pyodbc://scott:tiger@ms_2008
+mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+mariadb = mariadb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
 oracle=oracle://scott:tiger@127.0.0.1:1521
 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
 
@@ -49,7 +113,7 @@ oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
 
 
 [tool:pytest]
-addopts= --tb native -v -r fxX -p no:warnings -p no:logging --maxfail=25
+addopts= --tb native -v -r sfxX -p no:warnings -p no:logging --maxfail=25
 python_files=tests/test_*.py
 
 
index 54374390dd047a313ada1be30ef8ac62f2777628..1ca5fd71969681fd3c4a9ab6145b954c207979d0 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,6 @@ import os
 import re
 import sys
 
-from setuptools import find_packages
 from setuptools import setup
 from setuptools.command.test import test as TestCommand
 
@@ -16,16 +15,6 @@ VERSION = (
 v.close()
 
 
-readme = os.path.join(os.path.dirname(__file__), "README.rst")
-
-requires = [
-    "SQLAlchemy>=1.3.0",
-    "Mako",
-    "python-editor>=0.3",
-    "python-dateutil",
-]
-
-
 class UseTox(TestCommand):
     RED = 31
     RESET_SEQ = "\033[0m"
@@ -42,40 +31,6 @@ class UseTox(TestCommand):
 
 
 setup(
-    name="alembic",
     version=VERSION,
-    description="A database migration tool for SQLAlchemy.",
-    long_description=open(readme).read(),
-    python_requires=(
-        ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
-    ),
-    classifiers=[
-        "Development Status :: 5 - Production/Stable",
-        "Environment :: Console",
-        "License :: OSI Approved :: MIT License",
-        "Intended Audience :: Developers",
-        "Programming Language :: Python",
-        "Programming Language :: Python :: 2",
-        "Programming Language :: Python :: 2.7",
-        "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.6",
-        "Programming Language :: Python :: 3.7",
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: Implementation :: CPython",
-        "Programming Language :: Python :: Implementation :: PyPy",
-        "Topic :: Database :: Front-Ends",
-    ],
-    keywords="SQLAlchemy migrations",
-    author="Mike Bayer",
-    author_email="mike@zzzcomputing.com",
-    url="https://alembic.sqlalchemy.org",
-    project_urls={"Issue Tracker": "https://github.com/sqlalchemy/alembic/"},
-    license="MIT",
-    packages=find_packages(".", exclude=["examples*", "test*"]),
-    include_package_data=True,
     cmdclass={"test": UseTox},
-    zip_safe=False,
-    install_requires=requires,
-    entry_points={"console_scripts": ["alembic = alembic.config:main"]},
 )
index a83dff5833678cd5c7c8b62acce8583ba6b44025..325bb45cc88b76d49f4ee46e5252396bd0776d57 100755 (executable)
@@ -10,6 +10,8 @@ import os
 
 import pytest
 
+os.environ["SQLALCHEMY_WARN_20"] = "true"
+
 pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
 
 
@@ -32,4 +34,12 @@ with open(bootstrap_file) as f:
     code = compile(f.read(), "bootstrap.py", "exec")
     to_bootstrap = "pytest"
     exec(code, globals(), locals())
-    from pytestplugin import *  # noqa
+    from sqlalchemy.testing.plugin.pytestplugin import *  # noqa
+
+    wrap_pytest_sessionstart = pytest_sessionstart  # noqa
+
+    def pytest_sessionstart(session):
+        wrap_pytest_sessionstart(session)
+        from alembic.testing import warnings
+
+        warnings.setup_filters()
index 830d4deaebb8277b26b5957849ae6a2d172dca71..8c818892649cd0631f7c89bb6eaea6b71ece4985 100644 (file)
@@ -271,3 +271,9 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def supports_identity_on_null(self):
         return self.identity_columns + exclusions.only_on(["oracle"])
+
+    @property
+    def legacy_engine(self):
+        return exclusions.only_if(
+            lambda config: not getattr(config.db, "_is_future", False)
+        )
index 94546ffbef418360fd7e74f160a88264de566c45..943e61ae424af89d416ecd6403b63c2ad2985e27 100644 (file)
@@ -14,9 +14,9 @@ from sqlalchemy import UniqueConstraint
 
 from alembic.testing import assertions
 from alembic.testing import config
-from alembic.testing import engines
 from alembic.testing import eq_
 from alembic.testing import TestBase
+from alembic.testing import util
 from alembic.testing.env import staging_env
 from alembic.util import sqla_compat
 from ._autogen_fixtures import AutogenFixtureTest
@@ -29,7 +29,7 @@ class NoUqReflection(object):
 
     def setUp(self):
         staging_env()
-        self.bind = eng = engines.testing_engine()
+        self.bind = eng = util.testing_engine()
 
         def unimpl(*arg, **kw):
             raise NotImplementedError()
@@ -1508,7 +1508,7 @@ class IncludeHooksTest(AutogenFixtureTest, TestBase):
 
 class TruncatedIdxTest(AutogenFixtureTest, TestBase):
     def setUp(self):
-        self.bind = engines.testing_engine()
+        self.bind = util.testing_engine()
         self.bind.dialect.max_identifier_length = 30
 
     def test_idx_matches_long(self):
index c9785f28473f0faa8cc0e3496d5cad5e71bbaa44..23ab364c1bde3b905938a319f05b7d2394e8a8ea 100644 (file)
@@ -24,7 +24,6 @@ from sqlalchemy.dialects import sqlite as sqlite_dialect
 from sqlalchemy.schema import CreateIndex
 from sqlalchemy.schema import CreateTable
 from sqlalchemy.sql import column
-from sqlalchemy.sql import select
 from sqlalchemy.sql import text
 
 from alembic.ddl import sqlite
@@ -40,6 +39,7 @@ from alembic.testing import mock
 from alembic.testing import TestBase
 from alembic.testing.fixtures import op_fixture
 from alembic.util import exc as alembic_exc
+from alembic.util.sqla_compat import _select
 from alembic.util.sqla_compat import sqla_14
 
 
@@ -851,8 +851,10 @@ class BatchApplyTest(TestBase):
 class BatchAPITest(TestBase):
     @contextmanager
     def _fixture(self, schema=None):
+
         migration_context = mock.Mock(
-            opts={}, impl=mock.MagicMock(__dialect__="sqlite")
+            opts={},
+            impl=mock.MagicMock(__dialect__="sqlite", connection=object()),
         )
         op = Operations(migration_context)
         batch = op.batch_alter_table(
@@ -1256,90 +1258,105 @@ class BatchRoundTripTest(TestBase):
             Column("x", Integer),
             mysql_engine="InnoDB",
         )
-        t1.create(self.conn)
+        with self.conn.begin():
+            t1.create(self.conn)
 
-        self.conn.execute(
-            t1.insert(),
-            [
-                {"id": 1, "data": "d1", "x": 5},
-                {"id": 2, "data": "22", "x": 6},
-                {"id": 3, "data": "8.5", "x": 7},
-                {"id": 4, "data": "9.46", "x": 8},
-                {"id": 5, "data": "d5", "x": 9},
-            ],
-        )
+            self.conn.execute(
+                t1.insert(),
+                [
+                    {"id": 1, "data": "d1", "x": 5},
+                    {"id": 2, "data": "22", "x": 6},
+                    {"id": 3, "data": "8.5", "x": 7},
+                    {"id": 4, "data": "9.46", "x": 8},
+                    {"id": 5, "data": "d5", "x": 9},
+                ],
+            )
         context = MigrationContext.configure(self.conn)
         self.op = Operations(context)
 
     @contextmanager
     def _sqlite_referential_integrity(self):
-        self.conn.execute("PRAGMA foreign_keys=ON")
+        self.conn.exec_driver_sql("PRAGMA foreign_keys=ON")
         try:
             yield
         finally:
-            self.conn.execute("PRAGMA foreign_keys=OFF")
+            self.conn.exec_driver_sql("PRAGMA foreign_keys=OFF")
+
+            # as these tests are typically intentional fails, clean out
+            # tables left over
+            m = MetaData()
+            m.reflect(self.conn)
+            with self.conn.begin():
+                m.drop_all(self.conn)
 
     def _no_pk_fixture(self):
-        nopk = Table(
-            "nopk",
-            self.metadata,
-            Column("a", Integer),
-            Column("b", Integer),
-            Column("c", Integer),
-            mysql_engine="InnoDB",
-        )
-        nopk.create(self.conn)
-        self.conn.execute(
-            nopk.insert(), [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}]
-        )
-        return nopk
+        with self.conn.begin():
+            nopk = Table(
+                "nopk",
+                self.metadata,
+                Column("a", Integer),
+                Column("b", Integer),
+                Column("c", Integer),
+                mysql_engine="InnoDB",
+            )
+            nopk.create(self.conn)
+            self.conn.execute(
+                nopk.insert(),
+                [{"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 4, "c": 5}],
+            )
+            return nopk
 
     def _table_w_index_fixture(self):
-        t = Table(
-            "t_w_ix",
-            self.metadata,
-            Column("id", Integer, primary_key=True),
-            Column("thing", Integer),
-            Column("data", String(20)),
-        )
-        Index("ix_thing", t.c.thing)
-        t.create(self.conn)
-        return t
+        with self.conn.begin():
+            t = Table(
+                "t_w_ix",
+                self.metadata,
+                Column("id", Integer, primary_key=True),
+                Column("thing", Integer),
+                Column("data", String(20)),
+            )
+            Index("ix_thing", t.c.thing)
+            t.create(self.conn)
+            return t
 
     def _boolean_fixture(self):
-        t = Table(
-            "hasbool",
-            self.metadata,
-            Column("x", Boolean(create_constraint=True, name="ck1")),
-            Column("y", Integer),
-        )
-        t.create(self.conn)
+        with self.conn.begin():
+            t = Table(
+                "hasbool",
+                self.metadata,
+                Column("x", Boolean(create_constraint=True, name="ck1")),
+                Column("y", Integer),
+            )
+            t.create(self.conn)
 
     def _timestamp_fixture(self):
-        t = Table("hasts", self.metadata, Column("x", DateTime()))
-        t.create(self.conn)
-        return t
+        with self.conn.begin():
+            t = Table("hasts", self.metadata, Column("x", DateTime()))
+            t.create(self.conn)
+            return t
 
     def _datetime_server_default_fixture(self):
         return func.datetime("now", "localtime")
 
     def _timestamp_w_expr_default_fixture(self):
-        t = Table(
-            "hasts",
-            self.metadata,
-            Column(
-                "x",
-                DateTime(),
-                server_default=self._datetime_server_default_fixture(),
-                nullable=False,
-            ),
-        )
-        t.create(self.conn)
-        return t
+        with self.conn.begin():
+            t = Table(
+                "hasts",
+                self.metadata,
+                Column(
+                    "x",
+                    DateTime(),
+                    server_default=self._datetime_server_default_fixture(),
+                    nullable=False,
+                ),
+            )
+            t.create(self.conn)
+            return t
 
     def _int_to_boolean_fixture(self):
-        t = Table("hasbool", self.metadata, Column("x", Integer))
-        t.create(self.conn)
+        with self.conn.begin():
+            t = Table("hasbool", self.metadata, Column("x", Integer))
+            t.create(self.conn)
 
     def test_change_type_boolean_to_int(self):
         self._boolean_fixture()
@@ -1365,15 +1382,16 @@ class BatchRoundTripTest(TestBase):
 
         import datetime
 
-        self.conn.execute(
-            t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
-        )
+        with self.conn.begin():
+            self.conn.execute(
+                t.insert(), {"x": datetime.datetime(2012, 5, 18, 15, 32, 5)}
+            )
 
         with self.op.batch_alter_table("hasts") as batch_op:
             batch_op.alter_column("x", type_=DateTime())
 
         eq_(
-            self.conn.execute(select([t.c.x])).fetchall(),
+            self.conn.execute(_select(t.c.x)).fetchall(),
             [(datetime.datetime(2012, 5, 18, 15, 32, 5),)],
         )
 
@@ -1388,10 +1406,14 @@ class BatchRoundTripTest(TestBase):
                 server_default=self._datetime_server_default_fixture(),
             )
 
-        self.conn.execute(t.insert())
-
-        row = self.conn.execute(select([t.c.x])).fetchone()
-        assert row["x"] is not None
+        with self.conn.begin():
+            self.conn.execute(t.insert())
+        res = self.conn.execute(_select(t.c.x))
+        if sqla_14:
+            assert res.scalar_one_or_none() is not None
+        else:
+            row = res.fetchone()
+            assert row["x"] is not None
 
     def test_drop_col_schematype(self):
         self._boolean_fixture()
@@ -1429,19 +1451,18 @@ class BatchRoundTripTest(TestBase):
             )
 
     def tearDown(self):
-        self.metadata.drop_all(self.conn)
+        in_t = getattr(self.conn, "in_transaction", lambda: False)
+        if in_t():
+            self.conn.rollback()
+        with self.conn.begin():
+            self.metadata.drop_all(self.conn)
         self.conn.close()
 
     def _assert_data(self, data, tablename="foo"):
-        eq_(
-            [
-                dict(row)
-                for row in self.conn.execute(
-                    text("select * from %s" % tablename)
-                )
-            ],
-            data,
-        )
+        res = self.conn.execute(text("select * from %s" % tablename))
+        if sqla_14:
+            res = res.mappings()
+        eq_([dict(row) for row in res], data)
 
     def test_ix_existing(self):
         self._table_w_index_fixture()
@@ -1486,8 +1507,9 @@ class BatchRoundTripTest(TestBase):
             Column("foo_id", Integer, ForeignKey("foo.id")),
             mysql_engine="InnoDB",
         )
-        bar.create(self.conn)
-        self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
+        with self.conn.begin():
+            bar.create(self.conn)
+            self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
 
         with self.op.batch_alter_table("foo", recreate=recreate) as batch_op:
             batch_op.alter_column(
@@ -1532,9 +1554,14 @@ class BatchRoundTripTest(TestBase):
             Column("data", String(50)),
             mysql_engine="InnoDB",
         )
-        bar.create(self.conn)
-        self.conn.execute(bar.insert(), {"id": 1, "data": "x", "bar_id": None})
-        self.conn.execute(bar.insert(), {"id": 2, "data": "y", "bar_id": 1})
+        with self.conn.begin():
+            bar.create(self.conn)
+            self.conn.execute(
+                bar.insert(), {"id": 1, "data": "x", "bar_id": None}
+            )
+            self.conn.execute(
+                bar.insert(), {"id": 2, "data": "y", "bar_id": 1}
+            )
 
         with self.op.batch_alter_table("bar", recreate=recreate) as batch_op:
             batch_op.alter_column(
@@ -1649,8 +1676,9 @@ class BatchRoundTripTest(TestBase):
             Column("foo_id", Integer, ForeignKey("foo.id")),
             mysql_engine="InnoDB",
         )
-        bar.create(self.conn)
-        self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
+        with self.conn.begin():
+            bar.create(self.conn)
+            self.conn.execute(bar.insert(), {"id": 1, "foo_id": 3})
 
         naming_convention = {
             "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"
@@ -1773,9 +1801,10 @@ class BatchRoundTripTest(TestBase):
             Column("flag", Boolean(create_constraint=True)),
             mysql_engine="InnoDB",
         )
-        bar.create(self.conn)
-        self.conn.execute(bar.insert(), {"id": 1, "flag": True})
-        self.conn.execute(bar.insert(), {"id": 2, "flag": False})
+        with self.conn.begin():
+            bar.create(self.conn)
+            self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+            self.conn.execute(bar.insert(), {"id": 2, "flag": False})
 
         with self.op.batch_alter_table("bar") as batch_op:
             batch_op.alter_column(
@@ -1795,15 +1824,16 @@ class BatchRoundTripTest(TestBase):
             Column("flag", Boolean(create_constraint=False)),
             mysql_engine="InnoDB",
         )
-        bar.create(self.conn)
-        self.conn.execute(bar.insert(), {"id": 1, "flag": True})
-        self.conn.execute(bar.insert(), {"id": 2, "flag": False})
-        self.conn.execute(
-            # override Boolean type which as of 1.1 coerces numerics
-            # to 1/0
-            text("insert into bar (id, flag) values (:id, :flag)"),
-            {"id": 3, "flag": 5},
-        )
+        with self.conn.begin():
+            bar.create(self.conn)
+            self.conn.execute(bar.insert(), {"id": 1, "flag": True})
+            self.conn.execute(bar.insert(), {"id": 2, "flag": False})
+            self.conn.execute(
+                # override Boolean type which as of 1.1 coerces numerics
+                # to 1/0
+                text("insert into bar (id, flag) values (:id, :flag)"),
+                {"id": 3, "flag": 5},
+            )
 
         with self.op.batch_alter_table(
             "bar",
@@ -2042,7 +2072,8 @@ class BatchRoundTripPostgresqlTest(BatchRoundTripTest):
             ),
             Column("y", Integer),
         )
-        t.create(self.conn)
+        with self.conn.begin():
+            t.create(self.conn)
 
     def _datetime_server_default_fixture(self):
         return func.current_timestamp()
index aedf6e9639430dbeff36df615a8b7cea01dbb005..09c641a2e721b6191fcf0d0529060a875b07c563 100644 (file)
@@ -237,23 +237,28 @@ class RoundTripTest(TestBase):
 
     def setUp(self):
         self.conn = config.db.connect()
-        self.conn.execute(
-            text(
-                """
-            create table foo(
-                id integer primary key,
-                data varchar(50),
-                x integer
-            )
-        """
+        with self.conn.begin():
+            self.conn.execute(
+                text(
+                    """
+                create table foo(
+                    id integer primary key,
+                    data varchar(50),
+                    x integer
+                )
+            """
+                )
             )
-        )
         context = MigrationContext.configure(self.conn)
         self.op = op.Operations(context)
         self.t1 = table("foo", column("id"), column("data"), column("x"))
 
+        self.trans = self.conn.begin()
+
     def tearDown(self):
-        self.conn.execute(text("drop table foo"))
+        self.trans.rollback()
+        with self.conn.begin():
+            self.conn.execute(text("drop table foo"))
         self.conn.close()
 
     def test_single_insert_round_trip(self):
index a616e9cc189f84ac7d984b91063b4d4c86c7e79c..8350e829a58041cd27da0f619b950ef312aa1fac 100644 (file)
@@ -399,7 +399,7 @@ finally:
         r2 = command.revision(self.cfg)
         db = _sqlite_file_db()
         command.upgrade(self.cfg, "head")
-        with db.connect() as conn:
+        with db.begin() as conn:
             conn.execute(
                 text("insert into alembic_version values ('%s')" % r2.revision)
             )
@@ -681,7 +681,7 @@ class StampMultipleHeadsTest(TestBase, _StampTest):
         command.stamp(self.cfg, [self.a])
 
         eng = _sqlite_file_db()
-        with eng.connect() as conn:
+        with eng.begin() as conn:
             result = conn.execute(
                 text("update alembic_version set version_num='fake'")
             )
index 4fb6bbe19c4bc9e7c8e00b43184d9adf93460639..63de6cd8f06a64b8f6fe243fe10a5c30c12c9205 100644 (file)
@@ -1,6 +1,7 @@
 #!coding: utf-8
 from alembic import command
 from alembic import testing
+from alembic import util
 from alembic.environment import EnvironmentContext
 from alembic.migration import MigrationContext
 from alembic.script import ScriptDirectory
@@ -11,7 +12,7 @@ from alembic.testing import is_
 from alembic.testing import is_false
 from alembic.testing import is_true
 from alembic.testing import mock
-from alembic.testing.assertions import expect_warnings
+from alembic.testing.assertions import expect_raises_message
 from alembic.testing.env import _no_sql_testing_config
 from alembic.testing.env import _sqlite_file_db
 from alembic.testing.env import clear_staging_env
@@ -94,10 +95,11 @@ def upgrade():
             command.upgrade(self.cfg, "arev", sql=True)
         assert "do some SQL thing with a % percent sign %" in buf.getvalue()
 
+    @config.requirements.legacy_engine
     @testing.uses_deprecated(
         r"The Engine.execute\(\) function/method is considered legacy"
     )
-    def test_warning_on_passing_engine(self):
+    def test_error_on_passing_engine(self):
         env = self._fixture()
 
         engine = _sqlite_file_db()
@@ -131,18 +133,15 @@ def downgrade():
             migration_fn(rev, context)
             return env.script._upgrade_revs(a_rev, rev)
 
-        with expect_warnings(
+        with expect_raises_message(
+            util.CommandError,
             r"'connection' argument to configure\(\) is "
-            r"expected to be a sqlalchemy.engine.Connection "
+            r"expected to be a sqlalchemy.engine.Connection ",
         ):
             env.configure(
                 connection=engine, fn=upgrade, transactional_ddl=False
             )
 
-        env.run_migrations()
-
-        eq_(migration_fn.mock_calls, [mock.call((), env._migration_context)])
-
 
 class MigrationTransactionTest(TestBase):
     __backend__ = True
@@ -238,7 +237,7 @@ class MigrationTransactionTest(TestBase):
         with context.begin_transaction():
             is_false(self.conn.in_transaction())
             with context.begin_transaction(_per_migration=True):
-                is_false(self.conn.in_transaction())
+                is_true(self.conn.in_transaction())
 
             is_false(self.conn.in_transaction())
         is_false(self.conn.in_transaction())
@@ -264,7 +263,7 @@ class MigrationTransactionTest(TestBase):
         with context.begin_transaction():
             is_false(self.conn.in_transaction())
             with context.begin_transaction(_per_migration=True):
-                is_false(self.conn.in_transaction())
+                is_true(self.conn.in_transaction())
 
             is_false(self.conn.in_transaction())
         is_false(self.conn.in_transaction())
@@ -334,18 +333,12 @@ class MigrationTransactionTest(TestBase):
         with context.begin_transaction():
             is_false(self.conn.in_transaction())
             with context.begin_transaction(_per_migration=True):
-                if context.impl.transactional_ddl:
-                    is_true(self.conn.in_transaction())
-                else:
-                    is_false(self.conn.in_transaction())
+                is_true(self.conn.in_transaction())
 
                 with context.autocommit_block():
                     is_false(self.conn.in_transaction())
 
-                if context.impl.transactional_ddl:
-                    is_true(self.conn.in_transaction())
-                else:
-                    is_false(self.conn.in_transaction())
+                is_true(self.conn.in_transaction())
 
             is_false(self.conn.in_transaction())
         is_false(self.conn.in_transaction())
diff --git a/tests/test_impl.py b/tests/test_impl.py
new file mode 100644 (file)
index 0000000..8a73b87
--- /dev/null
@@ -0,0 +1,45 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import Table
+from sqlalchemy.sql import text
+
+from alembic import testing
+from alembic.testing import eq_
+from alembic.testing.fixtures import FutureEngineMixin
+from alembic.testing.fixtures import TablesTest
+
+
+class ImplTest(TablesTest):
+    __only_on__ = "sqlite"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "some_table", metadata, Column("x", Integer), Column("y", Integer)
+        )
+
+    @testing.fixture
+    def impl(self, migration_context):
+        with migration_context.begin_transaction(_per_migration=True):
+            yield migration_context.impl
+
+    def test_execute_params(self, impl):
+        result = impl._exec(text("select :my_param"), params={"my_param": 5})
+        eq_(result.scalar(), 5)
+
+    def test_execute_multiparams(self, impl):
+        some_table = self.tables.some_table
+        impl._exec(
+            some_table.insert(),
+            multiparams=[{"x": 1, "y": 2}, {"x": 2, "y": 3}, {"x": 5, "y": 7}],
+        )
+        eq_(
+            impl._exec(
+                some_table.select().order_by(some_table.c.x)
+            ).fetchall(),
+            [(1, 2), (2, 3), (5, 7)],
+        )
+
+
+class FutureImplTest(FutureEngineMixin, ImplTest):
+    pass
index caef197f611677b481123b460d39d8fe3a767cab..ba43e3ab2abb516e7637458f558b1db9de871024 100644 (file)
@@ -594,10 +594,11 @@ class MySQLDefaultCompareTest(TestBase):
         clear_staging_env()
 
     def setUp(self):
-        self.metadata = MetaData(self.bind)
+        self.metadata = MetaData()
 
     def tearDown(self):
-        self.metadata.drop_all()
+        with config.db.begin() as conn:
+            self.metadata.drop_all(conn)
 
     def _compare_default_roundtrip(self, type_, txt, alternate=None):
         if alternate:
index 08f70d84310ba9b853a3d4c2391720b8c1af2d9e..10f17d462e5a846926c9d94f7d1b38eb90ca5685 100644 (file)
@@ -35,7 +35,6 @@ from alembic.autogenerate.compare import _compare_server_default
 from alembic.autogenerate.compare import _compare_tables
 from alembic.autogenerate.compare import _render_server_default_for_compare
 from alembic.migration import MigrationContext
-from alembic.operations import Operations
 from alembic.operations import ops
 from alembic.script import ScriptDirectory
 from alembic.testing import assert_raises_message
@@ -50,6 +49,7 @@ from alembic.testing.env import staging_env
 from alembic.testing.env import write_script
 from alembic.testing.fixtures import capture_context_buffer
 from alembic.testing.fixtures import op_fixture
+from alembic.testing.fixtures import TablesTest
 from alembic.testing.fixtures import TestBase
 from alembic.util import sqla_compat
 
@@ -436,11 +436,12 @@ class PGAutocommitBlockTest(TestBase):
         with self.conn.begin():
             self.conn.execute(text("DROP TYPE mood"))
 
-    def test_alter_enum(self):
-        context = MigrationContext.configure(connection=self.conn)
-        with context.begin_transaction(_per_migration=True):
-            with context.autocommit_block():
-                context.execute(text("ALTER TYPE mood ADD VALUE 'soso'"))
+    def test_alter_enum(self, migration_context):
+        with migration_context.begin_transaction(_per_migration=True):
+            with migration_context.autocommit_block():
+                migration_context.execute(
+                    text("ALTER TYPE mood ADD VALUE 'soso'")
+                )
 
 
 class PGOfflineEnumTest(TestBase):
@@ -546,58 +547,38 @@ def downgrade():
         assert "DROP TYPE pgenum" in buf.getvalue()
 
 
-class PostgresqlInlineLiteralTest(TestBase):
+class PostgresqlInlineLiteralTest(TablesTest):
     __only_on__ = "postgresql"
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        cls.bind = config.db
-        with config.db.connect() as conn:
-            conn.execute(
-                text(
-                    """
-                create table tab (
-                    col varchar(50)
-                )
-            """
-                )
-            )
-            conn.execute(
-                text(
-                    """
+    def define_tables(cls, metadata):
+        Table("tab", metadata, Column("col", String(50)))
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            text(
+                """
                 insert into tab (col) values
                     ('old data 1'),
                     ('old data 2.1'),
                     ('old data 3')
             """
-                )
             )
+        )
 
-    @classmethod
-    def teardown_class(cls):
-        with cls.bind.connect() as conn:
-            conn.execute(text("drop table tab"))
-
-    def setUp(self):
-        self.conn = self.bind.connect()
-        ctx = MigrationContext.configure(self.conn)
-        self.op = Operations(ctx)
-
-    def tearDown(self):
-        self.conn.close()
-
-    def test_inline_percent(self):
+    def test_inline_percent(self, connection, ops_context):
         # TODO: here's the issue, you need to escape this.
         tab = table("tab", column("col"))
-        self.op.execute(
+        ops_context.execute(
             tab.update()
-            .where(tab.c.col.like(self.op.inline_literal("%.%")))
-            .values(col=self.op.inline_literal("new data")),
+            .where(tab.c.col.like(ops_context.inline_literal("%.%")))
+            .values(col=ops_context.inline_literal("new data")),
             execution_options={"no_parameters": True},
         )
         eq_(
-            self.conn.execute(
+            connection.execute(
                 text("select count(*) from tab where col='new data'")
             ).scalar(),
             1,
@@ -618,7 +599,7 @@ class PostgresqlDefaultCompareTest(TestBase):
         )
 
     def setUp(self):
-        self.metadata = MetaData(self.bind)
+        self.metadata = MetaData()
         self.autogen_context = api.AutogenContext(self.migration_context)
 
     @classmethod
@@ -626,7 +607,8 @@ class PostgresqlDefaultCompareTest(TestBase):
         clear_staging_env()
 
     def tearDown(self):
-        self.metadata.drop_all()
+        with config.db.begin() as conn:
+            self.metadata.drop_all(conn)
 
     def _compare_default_roundtrip(
         self, type_, orig_default, alternate=None, diff_expected=None
index 17bf0375ae8cd64e9e49ace007848622e120ef09..e1b094f6e0017d801ce9270a642d39e9b0ab6de7 100644 (file)
@@ -5,12 +5,16 @@ import os
 import re
 import textwrap
 
+import sqlalchemy as sa
+
 from alembic import command
+from alembic import testing
 from alembic import util
 from alembic.environment import EnvironmentContext
 from alembic.script import Script
 from alembic.script import ScriptDirectory
 from alembic.testing import assert_raises_message
+from alembic.testing import config
 from alembic.testing import eq_
 from alembic.testing import mock
 from alembic.testing.env import _no_sql_testing_config
@@ -22,31 +26,81 @@ from alembic.testing.env import staging_env
 from alembic.testing.env import three_rev_fixture
 from alembic.testing.env import write_script
 from alembic.testing.fixtures import capture_context_buffer
+from alembic.testing.fixtures import FutureEngineMixin
 from alembic.testing.fixtures import TestBase
 from alembic.util import compat
 
 
-class ApplyVersionsFunctionalTest(TestBase):
+class PatchEnvironment(object):
+    @contextmanager
+    def _patch_environment(self, transactional_ddl, transaction_per_migration):
+        conf = EnvironmentContext.configure
+
+        conn = [None]
+
+        def configure(*arg, **opt):
+            opt.update(
+                transactional_ddl=transactional_ddl,
+                transaction_per_migration=transaction_per_migration,
+            )
+            conn[0] = opt["connection"]
+            return conf(*arg, **opt)
+
+        with mock.patch.object(EnvironmentContext, "configure", configure):
+            yield
+
+            # it's no longer possible for the conn to be in a transaction
+            # assuming normal env.py as context.begin_transaction()
+            # will always run a real DB transaction, no longer uses autocommit
+            # mode
+            assert not conn[0].in_transaction()
+
+
+@testing.combinations(
+    (
+        False,
+        True,
+    ),
+    (
+        True,
+        False,
+    ),
+    (
+        True,
+        True,
+    ),
+    argnames="transactional_ddl,transaction_per_migration",
+    id_="rr",
+)
+class ApplyVersionsFunctionalTest(PatchEnvironment, TestBase):
     __only_on__ = "sqlite"
 
     sourceless = False
+    future = False
+    transactional_ddl = False
+    transaction_per_migration = True
 
     def setUp(self):
-        self.bind = _sqlite_file_db()
+        self.bind = _sqlite_file_db(future=self.future)
         self.env = staging_env(sourceless=self.sourceless)
-        self.cfg = _sqlite_testing_config(sourceless=self.sourceless)
+        self.cfg = _sqlite_testing_config(
+            sourceless=self.sourceless, future=self.future
+        )
 
     def tearDown(self):
         clear_staging_env()
 
     def test_steps(self):
-        self._test_001_revisions()
-        self._test_002_upgrade()
-        self._test_003_downgrade()
-        self._test_004_downgrade()
-        self._test_005_upgrade()
-        self._test_006_upgrade_again()
-        self._test_007_stamp_upgrade()
+        with self._patch_environment(
+            self.transactional_ddl, self.transaction_per_migration
+        ):
+            self._test_001_revisions()
+            self._test_002_upgrade()
+            self._test_003_downgrade()
+            self._test_004_downgrade()
+            self._test_005_upgrade()
+            self._test_006_upgrade_again()
+            self._test_007_stamp_upgrade()
 
     def _test_001_revisions(self):
         self.a = a = util.rev_id()
@@ -166,22 +220,39 @@ class ApplyVersionsFunctionalTest(TestBase):
         assert not db.dialect.has_table(db.connect(), "bat")
 
 
+# class level combinations can't do the skips for SQLAlchemy 1.3
+# so we have a separate class
+@testing.combinations(
+    (
+        False,
+        True,
+    ),
+    (
+        True,
+        False,
+    ),
+    (
+        True,
+        True,
+    ),
+    argnames="transactional_ddl,transaction_per_migration",
+    id_="rr",
+)
+class FutureApplyVersionsTest(FutureEngineMixin, ApplyVersionsFunctionalTest):
+    future = True
+
+
 class SimpleSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
     sourceless = "simple"
 
 
-class NewFangledSourcelessEnvOnlyApplyVersionsTest(
-    ApplyVersionsFunctionalTest
-):
-    sourceless = "pep3147_envonly"
-
-    __requires__ = ("pep3147",)
-
-
-class NewFangledSourcelessEverythingApplyVersionsTest(
-    ApplyVersionsFunctionalTest
-):
-    sourceless = "pep3147_everything"
+@testing.combinations(
+    ("pep3147_envonly",),
+    ("pep3147_everything",),
+    argnames="sourceless",
+    id_="r",
+)
+class NewFangledSourcelessApplyVersionsTest(ApplyVersionsFunctionalTest):
 
     __requires__ = ("pep3147",)
 
@@ -313,13 +384,17 @@ class OfflineTransactionalDDLTest(TestBase):
         )
 
 
-class OnlineTransactionalDDLTest(TestBase):
+class OnlineTransactionalDDLTest(PatchEnvironment, TestBase):
     def tearDown(self):
         clear_staging_env()
 
-    def _opened_transaction_fixture(self):
+    def _opened_transaction_fixture(self, future=False):
         self.env = staging_env()
-        self.cfg = _sqlite_testing_config()
+
+        if future:
+            self.cfg = _sqlite_testing_config(future=future)
+        else:
+            self.cfg = _sqlite_testing_config()
 
         script = ScriptDirectory.from_config(self.cfg)
         a = util.rev_id()
@@ -358,6 +433,8 @@ from alembic import op
 
 def upgrade():
     conn = op.get_bind()
+    # this should fail for a SQLAlchemy 2.0 connection b.c. there is
+    # already a transaction.
     trans = conn.begin()
 
 
@@ -391,59 +468,89 @@ def downgrade():
         )
         return a, b, c
 
-    @contextmanager
-    def _patch_environment(self, transactional_ddl, transaction_per_migration):
-        conf = EnvironmentContext.configure
-
-        def configure(*arg, **opt):
-            opt.update(
-                transactional_ddl=transactional_ddl,
-                transaction_per_migration=transaction_per_migration,
-            )
-            return conf(*arg, **opt)
-
-        with mock.patch.object(EnvironmentContext, "configure", configure):
-            yield
+    # these tests might not be supported anymore; the connection is always
+    # going to be in a transaction now even on 1.3.
 
-    def test_raise_when_rev_leaves_open_transaction(self):
-        a, b, c = self._opened_transaction_fixture()
+    @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
+    def test_raise_when_rev_leaves_open_transaction(self, future):
+        a, b, c = self._opened_transaction_fixture(future)
 
         with self._patch_environment(
             transactional_ddl=False, transaction_per_migration=False
         ):
-            assert_raises_message(
-                util.CommandError,
-                r'Migration "upgrade .*, rev b" has left an uncommitted '
-                r"transaction opened; transactional_ddl is False so Alembic "
-                r"is not committing transactions",
-                command.upgrade,
-                self.cfg,
-                c,
-            )
+            if future:
+                with testing.expect_raises_message(
+                    sa.exc.InvalidRequestError,
+                    "a transaction is already begun",
+                ):
+                    command.upgrade(self.cfg, c)
+            elif config.requirements.sqlalchemy_14.enabled:
+                if self.is_sqlalchemy_future:
+                    with testing.expect_raises_message(
+                        sa.exc.InvalidRequestError,
+                        r"a transaction is already begun for this connection",
+                    ):
+                        command.upgrade(self.cfg, c)
+                else:
+                    with testing.expect_sqlalchemy_deprecated_20(
+                        r"Calling .begin\(\) when a transaction "
+                        "is already begun"
+                    ):
+                        command.upgrade(self.cfg, c)
+            else:
+                command.upgrade(self.cfg, c)
 
-    def test_raise_when_rev_leaves_open_transaction_tpm(self):
-        a, b, c = self._opened_transaction_fixture()
+    @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
+    def test_raise_when_rev_leaves_open_transaction_tpm(self, future):
+        a, b, c = self._opened_transaction_fixture(future)
 
         with self._patch_environment(
             transactional_ddl=False, transaction_per_migration=True
         ):
-            assert_raises_message(
-                util.CommandError,
-                r'Migration "upgrade .*, rev b" has left an uncommitted '
-                r"transaction opened; transactional_ddl is False so Alembic "
-                r"is not committing transactions",
-                command.upgrade,
-                self.cfg,
-                c,
-            )
+            if future:
+                with testing.expect_raises_message(
+                    sa.exc.InvalidRequestError,
+                    "a transaction is already begun",
+                ):
+                    command.upgrade(self.cfg, c)
+            elif config.requirements.sqlalchemy_14.enabled:
+                if self.is_sqlalchemy_future:
+                    with testing.expect_raises_message(
+                        sa.exc.InvalidRequestError,
+                        r"a transaction is already begun for this connection",
+                    ):
+                        command.upgrade(self.cfg, c)
+                else:
+                    with testing.expect_sqlalchemy_deprecated_20(
+                        r"Calling .begin\(\) when a transaction is "
+                        "already begun"
+                    ):
+                        command.upgrade(self.cfg, c)
+            else:
+                command.upgrade(self.cfg, c)
 
-    def test_noerr_rev_leaves_open_transaction_transactional_ddl(self):
+    @testing.combinations((False,), (True, config.requirements.sqlalchemy_14))
+    def test_noerr_rev_leaves_open_transaction_transactional_ddl(self, future):
         a, b, c = self._opened_transaction_fixture()
 
         with self._patch_environment(
             transactional_ddl=True, transaction_per_migration=False
         ):
-            command.upgrade(self.cfg, c)
+            if config.requirements.sqlalchemy_14.enabled:
+                if self.is_sqlalchemy_future:
+                    with testing.expect_raises_message(
+                        sa.exc.InvalidRequestError,
+                        r"a transaction is already begun for this connection",
+                    ):
+                        command.upgrade(self.cfg, c)
+                else:
+                    with testing.expect_sqlalchemy_deprecated_20(
+                        r"Calling .begin\(\) when a transaction "
+                        "is already begun"
+                    ):
+                        command.upgrade(self.cfg, c)
+            else:
+                command.upgrade(self.cfg, c)
 
     def test_noerr_transaction_opened_externally(self):
         a, b, c = self._opened_transaction_fixture()
@@ -477,6 +584,12 @@ run_migrations_online()
         command.stamp(self.cfg, c)
 
 
+class FutureOnlineTransactionalDDLTest(
+    FutureEngineMixin, OnlineTransactionalDDLTest
+):
+    pass
+
+
 class EncodingTest(TestBase):
     def setUp(self):
         self.env = staging_env()
index 3ea1975c1fa2a9acf6f4445d82d2a0a66ea55d94..946f69f64bc1be4a31b72fc729fcaba3d38afddf 100644 (file)
@@ -96,7 +96,7 @@ class SQLiteDefaultCompareTest(TestBase):
         )
 
     def setUp(self):
-        self.metadata = MetaData(self.bind)
+        self.metadata = MetaData()
         self.autogen_context = api.AutogenContext(self.migration_context)
 
     @classmethod
@@ -104,7 +104,7 @@ class SQLiteDefaultCompareTest(TestBase):
         clear_staging_env()
 
     def tearDown(self):
-        self.metadata.drop_all()
+        self.metadata.drop_all(config.db)
 
     def _compare_default_roundtrip(
         self, type_, orig_default, alternate=None, diff_expected=None
index 1801346ce3c22f571cb405c30f5f528b1ef12b38..5ad3c21d4bb1c4f9a60c585771397ac582acc63b 100644 (file)
@@ -39,7 +39,8 @@ class TestMigrationContext(TestBase):
 
     def tearDown(self):
         self.transaction.rollback()
-        version_table.drop(self.connection, checkfirst=True)
+        with self.connection.begin():
+            version_table.drop(self.connection, checkfirst=True)
         self.connection.close()
 
     def make_one(self, **kwargs):
@@ -182,11 +183,16 @@ class UpdateRevTest(TestBase):
         self.context = migration.MigrationContext.configure(
             connection=self.connection, opts={"version_table": "version_table"}
         )
-        version_table.create(self.connection)
+        with self.connection.begin():
+            version_table.create(self.connection)
         self.updater = migration.HeadMaintainer(self.context, ())
 
     def tearDown(self):
-        version_table.drop(self.connection, checkfirst=True)
+        in_t = getattr(self.connection, "in_transaction", lambda: False)
+        if in_t():
+            self.connection.rollback()
+        with self.connection.begin():
+            version_table.drop(self.connection, checkfirst=True)
         self.connection.close()
 
     def _assert_heads(self, heads):
@@ -194,145 +200,176 @@ class UpdateRevTest(TestBase):
         eq_(self.updater.heads, set(heads))
 
     def test_update_none_to_single(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self._assert_heads(("a",))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self._assert_heads(("a",))
 
     def test_update_single_to_single(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self.updater.update_to_step(_up("a", "b"))
-        self._assert_heads(("b",))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self.updater.update_to_step(_up("a", "b"))
+            self._assert_heads(("b",))
 
     def test_update_single_to_none(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self.updater.update_to_step(_down("a", None, True))
-        self._assert_heads(())
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self.updater.update_to_step(_down("a", None, True))
+            self._assert_heads(())
 
     def test_add_branches(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self.updater.update_to_step(_up("a", "b"))
-        self.updater.update_to_step(_up(None, "c", True))
-        self._assert_heads(("b", "c"))
-        self.updater.update_to_step(_up("c", "d"))
-        self.updater.update_to_step(_up("d", "e1"))
-        self.updater.update_to_step(_up("d", "e2", True))
-        self._assert_heads(("b", "e1", "e2"))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self.updater.update_to_step(_up("a", "b"))
+            self.updater.update_to_step(_up(None, "c", True))
+            self._assert_heads(("b", "c"))
+            self.updater.update_to_step(_up("c", "d"))
+            self.updater.update_to_step(_up("d", "e1"))
+            self.updater.update_to_step(_up("d", "e2", True))
+            self._assert_heads(("b", "e1", "e2"))
 
     def test_teardown_branches(self):
-        self.updater.update_to_step(_up(None, "d1", True))
-        self.updater.update_to_step(_up(None, "d2", True))
-        self._assert_heads(("d1", "d2"))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "d1", True))
+            self.updater.update_to_step(_up(None, "d2", True))
+            self._assert_heads(("d1", "d2"))
 
-        self.updater.update_to_step(_down("d1", "c"))
-        self._assert_heads(("c", "d2"))
+            self.updater.update_to_step(_down("d1", "c"))
+            self._assert_heads(("c", "d2"))
 
-        self.updater.update_to_step(_down("d2", "c", True))
+            self.updater.update_to_step(_down("d2", "c", True))
 
-        self._assert_heads(("c",))
-        self.updater.update_to_step(_down("c", "b"))
-        self._assert_heads(("b",))
+            self._assert_heads(("c",))
+            self.updater.update_to_step(_down("c", "b"))
+            self._assert_heads(("b",))
 
     def test_resolve_merges(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self.updater.update_to_step(_up("a", "b"))
-        self.updater.update_to_step(_up("b", "c1"))
-        self.updater.update_to_step(_up("b", "c2", True))
-        self.updater.update_to_step(_up("c1", "d1"))
-        self.updater.update_to_step(_up("c2", "d2"))
-        self._assert_heads(("d1", "d2"))
-        self.updater.update_to_step(_up(("d1", "d2"), "e"))
-        self._assert_heads(("e",))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self.updater.update_to_step(_up("a", "b"))
+            self.updater.update_to_step(_up("b", "c1"))
+            self.updater.update_to_step(_up("b", "c2", True))
+            self.updater.update_to_step(_up("c1", "d1"))
+            self.updater.update_to_step(_up("c2", "d2"))
+            self._assert_heads(("d1", "d2"))
+            self.updater.update_to_step(_up(("d1", "d2"), "e"))
+            self._assert_heads(("e",))
 
     def test_unresolve_merges(self):
-        self.updater.update_to_step(_up(None, "e", True))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "e", True))
 
-        self.updater.update_to_step(_down("e", ("d1", "d2")))
-        self._assert_heads(("d2", "d1"))
+            self.updater.update_to_step(_down("e", ("d1", "d2")))
+            self._assert_heads(("d2", "d1"))
 
-        self.updater.update_to_step(_down("d2", "c2"))
-        self._assert_heads(("c2", "d1"))
+            self.updater.update_to_step(_down("d2", "c2"))
+            self._assert_heads(("c2", "d1"))
 
     def test_update_no_match(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self.updater.heads.add("x")
-        assert_raises_message(
-            CommandError,
-            "Online migration expected to match one row when updating "
-            "'x' to 'b' in 'version_table'; 0 found",
-            self.updater.update_to_step,
-            _up("x", "b"),
-        )
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self.updater.heads.add("x")
+            assert_raises_message(
+                CommandError,
+                "Online migration expected to match one row when updating "
+                "'x' to 'b' in 'version_table'; 0 found",
+                self.updater.update_to_step,
+                _up("x", "b"),
+            )
 
     def test_update_no_match_no_sane_rowcount(self):
-        self.updater.update_to_step(_up(None, "a", True))
-        self.updater.heads.add("x")
-        with mock.patch.object(
-            self.connection.dialect, "supports_sane_rowcount", False
-        ):
-            self.updater.update_to_step(_up("x", "b"))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+            self.updater.heads.add("x")
+            with mock.patch.object(
+                self.connection.dialect, "supports_sane_rowcount", False
+            ):
+                self.updater.update_to_step(_up("x", "b"))
 
     def test_update_multi_match(self):
-        self.connection.execute(version_table.insert(), version_num="a")
-        self.connection.execute(version_table.insert(), version_num="a")
-
-        self.updater.heads.add("a")
-        assert_raises_message(
-            CommandError,
-            "Online migration expected to match one row when updating "
-            "'a' to 'b' in 'version_table'; 2 found",
-            self.updater.update_to_step,
-            _up("a", "b"),
-        )
+        with self.connection.begin():
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+
+            self.updater.heads.add("a")
+            assert_raises_message(
+                CommandError,
+                "Online migration expected to match one row when updating "
+                "'a' to 'b' in 'version_table'; 2 found",
+                self.updater.update_to_step,
+                _up("a", "b"),
+            )
 
     def test_update_multi_match_no_sane_rowcount(self):
-        self.connection.execute(version_table.insert(), version_num="a")
-        self.connection.execute(version_table.insert(), version_num="a")
-
-        self.updater.heads.add("a")
-        with mock.patch.object(
-            self.connection.dialect, "supports_sane_rowcount", False
-        ):
-            self.updater.update_to_step(_up("a", "b"))
+        with self.connection.begin():
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+
+            self.updater.heads.add("a")
+            with mock.patch.object(
+                self.connection.dialect, "supports_sane_rowcount", False
+            ):
+                self.updater.update_to_step(_up("a", "b"))
 
     def test_delete_no_match(self):
-        self.updater.update_to_step(_up(None, "a", True))
-
-        self.updater.heads.add("x")
-        assert_raises_message(
-            CommandError,
-            "Online migration expected to match one row when "
-            "deleting 'x' in 'version_table'; 0 found",
-            self.updater.update_to_step,
-            _down("x", None, True),
-        )
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
+
+            self.updater.heads.add("x")
+            assert_raises_message(
+                CommandError,
+                "Online migration expected to match one row when "
+                "deleting 'x' in 'version_table'; 0 found",
+                self.updater.update_to_step,
+                _down("x", None, True),
+            )
 
     def test_delete_no_matchno_sane_rowcount(self):
-        self.updater.update_to_step(_up(None, "a", True))
+        with self.connection.begin():
+            self.updater.update_to_step(_up(None, "a", True))
 
-        self.updater.heads.add("x")
-        with mock.patch.object(
-            self.connection.dialect, "supports_sane_rowcount", False
-        ):
-            self.updater.update_to_step(_down("x", None, True))
+            self.updater.heads.add("x")
+            with mock.patch.object(
+                self.connection.dialect, "supports_sane_rowcount", False
+            ):
+                self.updater.update_to_step(_down("x", None, True))
 
     def test_delete_multi_match(self):
-        self.connection.execute(version_table.insert(), version_num="a")
-        self.connection.execute(version_table.insert(), version_num="a")
-
-        self.updater.heads.add("a")
-        assert_raises_message(
-            CommandError,
-            "Online migration expected to match one row when "
-            "deleting 'a' in 'version_table'; 2 found",
-            self.updater.update_to_step,
-            _down("a", None, True),
-        )
+        with self.connection.begin():
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+
+            self.updater.heads.add("a")
+            assert_raises_message(
+                CommandError,
+                "Online migration expected to match one row when "
+                "deleting 'a' in 'version_table'; 2 found",
+                self.updater.update_to_step,
+                _down("a", None, True),
+            )
 
     def test_delete_multi_match_no_sane_rowcount(self):
-        self.connection.execute(version_table.insert(), version_num="a")
-        self.connection.execute(version_table.insert(), version_num="a")
-
-        self.updater.heads.add("a")
-        with mock.patch.object(
-            self.connection.dialect, "supports_sane_rowcount", False
-        ):
-            self.updater.update_to_step(_down("a", None, True))
+        with self.connection.begin():
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+            self.connection.execute(
+                version_table.insert(), dict(version_num="a")
+            )
+
+            self.updater.heads.add("a")
+            with mock.patch.object(
+                self.connection.dialect, "supports_sane_rowcount", False
+            ):
+                self.updater.update_to_step(_down("a", None, True))
diff --git a/tox.ini b/tox.ini
index 52f684295fb7e534c9ba5d1531ea6cdb54a727c0..a33263595ccd987d52a66451a9bc631ea9e2d01a 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 
-envlist = py
+envlist = py-sqlalchemy
 
 SQLA_REPO = {env:SQLA_REPO:git+https://github.com/sqlalchemy/sqlalchemy.git}
 
@@ -10,8 +10,8 @@ cov_args=--cov=alembic --cov-report term --cov-report xml
 deps=pytest>4.6
      pytest-xdist
      mock
-     sqla13: {[tox]SQLA_REPO}@rel_1_3
-     sqlamaster: {[tox]SQLA_REPO}@master
+     sqla13: {[tox]SQLA_REPO}@rel_1_3#egg=sqlalchemy
+     sqlamaster: {[tox]SQLA_REPO}@master#egg=sqlalchemy
      postgresql: psycopg2
      mysql: mysqlclient
      mysql: pymysql
@@ -19,6 +19,10 @@ deps=pytest>4.6
      oracle: cx_oracle>=7;python_version>="3"
      mssql: pymssql
      cov: pytest-cov
+     sqlalchemy: sqlalchemy>=1.3.0
+     mako
+     python-editor>=0.3
+     python-dateutil
 
 
 
@@ -39,6 +43,9 @@ setenv=
     mssql: MSSQL={env:TOX_MSSQL:--db pymssql}
     pyoptimize: PYTHONOPTIMIZE=1
     pyoptimize: LIMITTESTS="tests/test_script_consumption.py"
+    future: SQLALCHEMY_TESTING_FUTURE_ENGINE=1
+    SQLALCHEMY_WARN_20=1
+
 
 # tox as of 2.0 blocks all environment variables from the
 # outside, unless they are here (or in TOX_TESTENV_PASSENV,
@@ -64,4 +71,4 @@ deps=
       black==20.8b1
 commands =
      flake8 ./alembic/ ./tests/ setup.py docs/build/conf.py {posargs}
-     black --check .
+     black --check setup.py tests alembic