reinvent xdist hooks in terms of pytest fixtures
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sun, 10 Jan 2021 18:44:14 +0000 (13:44 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 14 Jan 2021 03:10:13 +0000 (22:10 -0500)
To allow the "connection" pytest fixture and others to work
correctly in conjunction with setup/teardown that expects
to be external to the transaction, remove and prevent any usage
of "xdist" style names that are hardcoded by pytest to run
inside of fixtures, even function level ones.   Instead use
pytest autouse fixtures to implement our own
r"setup|teardown_test(?:_class)?" methods so that we can ensure
function-scoped fixtures are run within them.   A new more
explicit flow is set up within plugin_base and pytestplugin
such that the order of setup/teardown steps, of which there are now
many, is fully documented and controllable.   New granularity
has been added to the test teardown phase to distinguish
between "end of the test" when lock-holding structures on
connections should be released to allow for table drops,
vs. "end of the test plus its teardown steps" when we can
perform final cleanup on connections and run assertions
that everything is closed out.
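The lifecycle ordering above can be pictured with a small sketch; this is
illustrative only (class and method names are invented, not the actual
plugin code), showing how an autouse fixture lets our own setup/teardown
methods wrap each test:

```python
# Illustrative only: invented names, not the actual plugin code.
try:
    import pytest
except ImportError:
    # minimal shim so this sketch runs even without pytest installed
    class pytest(object):
        @staticmethod
        def fixture(**kw):
            def deco(fn):
                return fn
            return deco


class ExampleTest(object):
    @pytest.fixture(autouse=True)
    def _test_lifecycle(self):
        # code before the yield runs as per-test setup; code after the
        # yield runs as per-test teardown, inside the fixture machinery
        self.setup_test()
        yield
        self.teardown_test()

    def setup_test(self):
        self.resources = ["connection"]

    def teardown_test(self):
        self.resources.clear()

    def test_uses_resource(self):
        assert self.resources == ["connection"]
```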

From there we can remove most of the defensive "tear down everything"
logic inside of engines which for many years would frequently dispose
of pools over and over again, creating a broken and expensive
connection flow.  A quick test shows that running test/sql/ against
a single PostgreSQL engine with the new approach uses 75% fewer new
connections, creating 42 new connections total, vs. 164 new
connections total with the previous system.

As part of this, the new fixtures metadata/connection/future_connection
have been integrated such that they can be combined together
effectively.  The fixture_session(), provide_metadata() fixtures
have been improved, including that fixture_session() now strongly
references sessions which are explicitly torn down before
table drops occur after a test.

Major changes have been made to the
ConnectionKiller such that it now features different "scopes" for
testing engines and will limit its cleanup to those testing
engines corresponding to end of test, end of test class, or
end of test session.   The system by which it tracks DBAPI
connections has been reworked; it is ultimately somewhat similar to
how it worked before, but is organized more clearly along
with the proxy-tracking logic.  A "testing_engine" fixture
is also added that works as a pytest fixture rather than a
standalone function.  The connection cleanup logic should
now be very robust, as we can now use the same global
connection pools for the whole suite without ever disposing
them, while also running a query for PostgreSQL
locks remaining after every test and asserting there are no open
transactions leaking between tests at all.  Additional steps
are added that also accommodate asyncio connections not
explicitly closed, as is the case for legacy sync-style
tests as well as the async tests themselves.
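The scope-keyed tracking can be pictured with a minimal stand-in registry
(modeled loosely on the reworked ConnectionKiller in this commit, with
fake engines so the sketch is self-contained):

```python
# Fake engines; the registry is modeled loosely on ConnectionKiller,
# which limits cleanup to the engines registered under a given scope.
import collections


class ScopedEngineRegistry(object):
    SCOPES = ("function", "class", "global", "fixture")

    def __init__(self):
        self.testing_engines = collections.defaultdict(set)

    def add_engine(self, engine, scope):
        assert scope in self.SCOPES
        self.testing_engines[scope].add(engine)

    def drop_testing_engines(self, scope):
        # dispose only the engines registered for this scope, e.g.
        # "function" at end of test, "class" at end of test class
        for engine in list(self.testing_engines[scope]):
            engine.dispose()
        self.testing_engines[scope].clear()


class FakeEngine(object):
    def __init__(self):
        self.disposed = False

    def dispose(self):
        self.disposed = True
```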

As always, hundreds of tests are further refined to use the
new fixtures where problems with loose connections were identified,
largely as a result of the new PostgreSQL assertions; many more
tests have moved from legacy patterns to the newest ones.

An unfortunate discovery during the creation of this system is that
autouse fixtures (as well as fixtures set up by
@pytest.mark.usefixtures) are not usable at our current scale with pytest
4.6.11 running under Python 2.  It's unclear if this is due
to the older version of pytest or how it implements itself for
Python 2, or whether the issue is CPU slowness or just large
memory use, but collecting the full span of tests takes over
a minute for a single process when any autouse fixtures are in
place, and on CI the jobs simply time out after ten minutes.
So at the moment this patch also reinvents a small version of
"autouse" fixtures when py2k is running, which skips generating
the real fixture and instead uses two global pytest fixtures
(which don't seem to impact performance) to invoke the
"autouse" fixtures ourselves outside of pytest.
This will limit our ability to do more with fixtures
until we can remove py2k support.

py.test is still observed to be much slower in collection in the
4.6.11 version compared to modern 6.2 versions, so add support for new
TOX_POSTGRESQL_PY2K and TOX_MYSQL_PY2K environment variables that
will run the suite for fewer backends under Python 2.  For Python 3
pin pytest to modern 6.2 versions where performance for collection
has been improved greatly.

Includes the following improvements:

Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would
be raised rather than :class:`.exc.TimeoutError`.  Also repaired the
:paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using
the async engine, which previously would ignore the timeout and block
rather than timing out immediately as is the behavior with regular
:class:`.QueuePool`.
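The exception-translation half of this fix follows a standard pattern; a
stdlib-only sketch under the assumption that pool acquisition is wrapped
in asyncio.wait_for (the exception class below is a stand-in for
sqlalchemy.exc.TimeoutError, not the real one):

```python
# Stand-in sketch of translating the driver-level asyncio.TimeoutError
# into the pool's own timeout exception.
import asyncio


class PoolTimeoutError(Exception):
    """Stand-in for sqlalchemy.exc.TimeoutError."""


async def acquire_with_timeout(acquire, timeout):
    # a timeout of zero fails immediately instead of blocking,
    # matching regular QueuePool behavior
    try:
        return await asyncio.wait_for(acquire, timeout)
    except asyncio.TimeoutError as err:
        raise PoolTimeoutError("connection pool timeout reached") from err


async def slow_acquire():
    # simulates a pool that never yields a connection in time
    await asyncio.sleep(60)
    return "connection"
```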

For asyncio the connection pool will now also not interact
at all with an asyncio connection whose ConnectionFairy is
being garbage collected; a warning that the connection was
not properly closed is emitted and the connection is discarded.
Within the test suite the ConnectionKiller is now maintaining
strong references to all DBAPI connections and ensuring they
are released when tests end, including those whose ConnectionFairy
proxies are GCed.
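The GC-time policy can be sketched with weakref callbacks, which is also
the mechanism the pool uses to trigger finalization; everything below is
a stand-in, not the real pool code:

```python
# Stand-ins only: when an asyncio connection's proxy is garbage
# collected, no IO (rollback/close) is attempted from the GC callback;
# the pool just warns and discards the raw connection.
import gc
import warnings
import weakref


class RawConnection(object):
    """Stand-in for a DBAPI connection."""

    closed = False

    def close(self):
        self.closed = True


def _finalize(connection, is_asyncio):
    # invoked from a weakref callback, i.e. during garbage collection
    if is_asyncio:
        warnings.warn(
            "asyncio connection is being garbage collected "
            "without being properly closed: %r" % connection
        )
    else:
        connection.close()  # safe to close synchronously


class ConnectionProxy(object):
    """Stand-in for the pool's ConnectionFairy proxy."""

    def __init__(self, connection, is_asyncio):
        self.connection = connection
        # the callback fires when the proxy is GCed without close()
        self._ref = weakref.ref(
            self, lambda ref: _finalize(connection, is_asyncio)
        )
```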

Identified cx_Oracle.stmtcachesize as a major factor in Oracle
test scalability issues; this can be reset on a per-test basis
rather than setting it to zero across the board.  The addition
of this change has resolved the long-standing Oracle "two task"
error problem.

For SQL Server, changed the temp table style used by the
"suite" tests to be the double-pound-sign, i.e. global,
variety, which is much easier to test generically.  There
are already reflection tests that are more finely tuned
to both styles of temp table within the mssql test
suite.  Additionally, added an extra step to the
"dropfirst" mechanism for SQL Server that will remove
all foreign key constraints first as some issues were
observed when using this flag when multiple schemas
had not been torn down.

Identified and fixed two subtle failure modes in the
engine: when commit/rollback fails in a begin()
context manager, the connection is now explicitly closed,
and when "initialize()" fails on the first new connection
of a dialect, the transactional state on that connection
is still rolled back.
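The first failure mode reduces to a try/finally in the context manager's
__exit__, mirroring the engine/base.py change in this commit's diff; the
objects below are fakes so the sketch is self-contained:

```python
# Fake transaction/connection objects demonstrating why the close must
# live in a finally block: a failing commit otherwise leaks the
# connection.
class FakeTransaction(object):
    def __init__(self, fail_commit=False):
        self.is_active = True
        self.fail_commit = fail_commit

    def commit(self):
        self.is_active = False
        if self.fail_commit:
            raise RuntimeError("commit failed")

    def rollback(self):
        self.is_active = False


class FakeConnection(object):
    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


class BeginContext(object):
    def __init__(self, conn, transaction):
        self.conn = conn
        self.transaction = transaction

    def __enter__(self):
        return self.conn

    def __exit__(self, type_, value, traceback):
        # the finally block guarantees the connection is closed even
        # when commit or rollback itself raises
        try:
            if type_ is not None:
                if self.transaction.is_active:
                    self.transaction.rollback()
            else:
                if self.transaction.is_active:
                    self.transaction.commit()
        finally:
            self.conn.close()
```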

Fixes: #5826
Fixes: #5827
Change-Id: Ib1d05cb8c7cf84f9a4bfd23df397dc23c9329bfe

115 files changed:
doc/build/changelog/unreleased_14/5823.rst [new file with mode: 0644]
doc/build/changelog/unreleased_14/5827.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/provision.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/dialects/oracle/provision.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/provision.py
lib/sqlalchemy/dialects/sqlite/provision.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/future/engine.py
lib/sqlalchemy/pool/base.py
lib/sqlalchemy/testing/__init__.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/engines.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/plugin/bootstrap.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/plugin/pytestplugin.py
lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py [new file with mode: 0644]
lib/sqlalchemy/testing/provision.py
lib/sqlalchemy/testing/suite/test_reflection.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/suite/test_types.py
lib/sqlalchemy/testing/util.py
lib/sqlalchemy/util/queue.py
test/aaa_profiling/test_compiler.py
test/aaa_profiling/test_memusage.py
test/aaa_profiling/test_misc.py
test/aaa_profiling/test_orm.py
test/aaa_profiling/test_pool.py
test/base/test_events.py
test/base/test_inspect.py
test/base/test_tutorials.py
test/dialect/mssql/test_compiler.py
test/dialect/mssql/test_deprecations.py
test/dialect/mssql/test_query.py
test/dialect/mysql/test_compiler.py
test/dialect/mysql/test_reflection.py
test/dialect/oracle/test_compiler.py
test/dialect/oracle/test_dialect.py
test/dialect/oracle/test_reflection.py
test/dialect/oracle/test_types.py
test/dialect/postgresql/test_async_pg_py3k.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_dialect.py
test/dialect/postgresql/test_query.py
test/dialect/postgresql/test_reflection.py
test/dialect/postgresql/test_types.py
test/dialect/test_sqlite.py
test/engine/test_ddlevents.py
test/engine/test_deprecations.py
test/engine/test_execute.py
test/engine/test_logging.py
test/engine/test_pool.py
test/engine/test_processors.py
test/engine/test_reconnect.py
test/engine/test_reflection.py
test/engine/test_transaction.py
test/ext/asyncio/test_engine_py3k.py
test/ext/declarative/test_inheritance.py
test/ext/declarative/test_reflection.py
test/ext/test_associationproxy.py
test/ext/test_baked.py
test/ext/test_compiler.py
test/ext/test_extendedattr.py
test/ext/test_horizontal_shard.py
test/ext/test_hybrid.py
test/ext/test_mutable.py
test/ext/test_orderinglist.py
test/orm/declarative/test_basic.py
test/orm/declarative/test_concurrency.py
test/orm/declarative/test_inheritance.py
test/orm/declarative/test_mixin.py
test/orm/declarative/test_reflection.py
test/orm/inheritance/test_basic.py
test/orm/test_attributes.py
test/orm/test_bind.py
test/orm/test_collection.py
test/orm/test_compile.py
test/orm/test_cycles.py
test/orm/test_deprecations.py
test/orm/test_eager_relations.py
test/orm/test_events.py
test/orm/test_froms.py
test/orm/test_lazy_relations.py
test/orm/test_load_on_fks.py
test/orm/test_mapper.py
test/orm/test_options.py
test/orm/test_query.py
test/orm/test_rel_fn.py
test/orm/test_relationships.py
test/orm/test_selectin_relations.py
test/orm/test_session.py
test/orm/test_subquery_relations.py
test/orm/test_transaction.py
test/orm/test_unitofwork.py
test/orm/test_unitofworkv2.py
test/requirements.py
test/sql/test_case_statement.py
test/sql/test_compare.py
test/sql/test_compiler.py
test/sql/test_defaults.py
test/sql/test_deprecations.py
test/sql/test_external_traversal.py
test/sql/test_from_linter.py
test/sql/test_functions.py
test/sql/test_metadata.py
test/sql/test_operators.py
test/sql/test_resultset.py
test/sql/test_sequences.py
test/sql/test_types.py
tox.ini

diff --git a/doc/build/changelog/unreleased_14/5823.rst b/doc/build/changelog/unreleased_14/5823.rst
new file mode 100644 (file)
index 0000000..74debda
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, pool, asyncio
+    :tickets: 5823
+
+    When using an asyncio engine, the connection pool will now detach and
+    discard a pooled connection that was not explicitly closed/returned to
+    the pool when its tracking object is garbage collected, emitting a warning
+    that the connection was not properly closed.   As this operation occurs
+    during Python gc finalizers, it's not safe to run any IO operations upon
+    the connection including transaction rollback or connection close as this
+    will often be outside of the event loop.
+
+
diff --git a/doc/build/changelog/unreleased_14/5827.rst b/doc/build/changelog/unreleased_14/5827.rst
new file mode 100644 (file)
index 0000000..d5c8acd
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, asyncio
+    :tickets: 5827
+
+    Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would
+    be raised rather than :class:`.exc.TimeoutError`.  Also repaired the
+    :paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using
+    the async engine, which previously would ignore the timeout and block
+    rather than timing out immediately as is the behavior with regular
+    :class:`.QueuePool`.
index 538679fcf432ad9c862d921998c6bcf23cb99ba7..0227e515d3035a77aa72b2de31bf585089593db5 100644 (file)
@@ -2785,15 +2785,14 @@ class MSDialect(default.DefaultDialect):
     def has_table(self, connection, tablename, dbname, owner, schema):
         if tablename.startswith("#"):  # temporary table
             tables = ischema.mssql_temp_table_columns
-            result = connection.execute(
-                sql.select(tables.c.table_name)
-                .where(
-                    tables.c.table_name.like(
-                        self._temp_table_name_like_pattern(tablename)
-                    )
+
+            s = sql.select(tables.c.table_name).where(
+                tables.c.table_name.like(
+                    self._temp_table_name_like_pattern(tablename)
                 )
-                .limit(1)
             )
+
+            result = connection.execute(s.limit(1))
             return result.scalar() is not None
         else:
             tables = ischema.tables
index 269eb164f70117a8b01f137c56e9a862b35eb965..56f3305a704105ebecaeec7e9b7a341da2b800a1 100644 (file)
@@ -1,6 +1,14 @@
+from sqlalchemy import inspect
+from sqlalchemy import Integer
 from ... import create_engine
 from ... import exc
+from ...schema import Column
+from ...schema import DropConstraint
+from ...schema import ForeignKeyConstraint
+from ...schema import MetaData
+from ...schema import Table
 from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_pre_tables
 from ...testing.provision import drop_db
 from ...testing.provision import get_temp_table_name
 from ...testing.provision import log
@@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident):
         #        "where database_id=db_id('%s')" % ident):
         #    log.info("killing SQL server session %s", row['session_id'])
         #    conn.exec_driver_sql("kill %s" % row['session_id'])
-
         conn.exec_driver_sql("drop database %s" % ident)
         log.info("Reaped db: %s", ident)
         return True
@@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng):
 
 @get_temp_table_name.for_db("mssql")
 def _mssql_get_temp_table_name(cfg, eng, base_name):
-    return "#" + base_name
+    return "##" + base_name
+
+
+@drop_all_schema_objects_pre_tables.for_db("mssql")
+def drop_all_schema_objects_pre_tables(cfg, eng):
+    with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+        inspector = inspect(conn)
+        for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
+            for tname in inspector.get_table_names(schema=schema):
+                tb = Table(
+                    tname,
+                    MetaData(),
+                    Column("x", Integer),
+                    Column("y", Integer),
+                    schema=schema,
+                )
+                for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
+                    conn.execute(
+                        DropConstraint(
+                            ForeignKeyConstraint(
+                                [tb.c.x], [tb.c.y], name=fk["name"]
+                            )
+                        )
+                    )
index 042443692dd5de5718e9acfa400a25616c38910d..b8b4df760c3ce3f16b91f65ec0bbd64582522517 100644 (file)
@@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows:
 
 * ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail.
 
+
 .. _cx_oracle_unicode:
 
 Unicode
index d51131c0b6169b8bdc74089843faa4cea7153d55..e0dadd58ea0e490d41ac05a3c9f8751f98d4918b 100644 (file)
@@ -6,11 +6,11 @@ from ...testing.provision import create_db
 from ...testing.provision import drop_db
 from ...testing.provision import follower_url_from_main
 from ...testing.provision import log
+from ...testing.provision import post_configure_engine
 from ...testing.provision import run_reap_dbs
 from ...testing.provision import set_default_schema_on_connection
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
 from ...testing.provision import temp_table_keyword_args
-from ...testing.provision import update_db_opts
 
 
 @create_db.for_db("oracle")
@@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident):
         _ora_drop_ignore(conn, "%s_ts2" % ident)
 
 
-@update_db_opts.for_db("oracle")
-def _oracle_update_db_opts(db_url, db_opts):
-    pass
+@stop_test_class_outside_fixtures.for_db("oracle")
+def stop_test_class_outside_fixtures(config, db, cls):
 
+    with db.begin() as conn:
+        # run magic command to get rid of identity sequences
+        # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/  # noqa E501
+        conn.exec_driver_sql("purge recyclebin")
 
-@stop_test_class.for_db("oracle")
-def stop_test_class(config, db, cls):
-    """run magic command to get rid of identity sequences
+    # clear statement cache on all connections that were used
+    # https://github.com/oracle/python-cx_Oracle/issues/519
 
-    # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/
+    for cx_oracle_conn in _all_conns:
+        try:
+            sc = cx_oracle_conn.stmtcachesize
+        except db.dialect.dbapi.InterfaceError:
+            # connection closed
+            pass
+        else:
+            cx_oracle_conn.stmtcachesize = 0
+            cx_oracle_conn.stmtcachesize = sc
+    _all_conns.clear()
 
-    """
 
-    with db.begin() as conn:
-        conn.exec_driver_sql("purge recyclebin")
+_all_conns = set()
+
+
+@post_configure_engine.for_db("oracle")
+def _oracle_post_configure_engine(url, engine, follower_ident):
+    from sqlalchemy import event
+
+    @event.listens_for(engine, "checkout")
+    def checkout(dbapi_con, con_record, con_proxy):
+        _all_conns.add(dbapi_con)
 
 
 @run_reap_dbs.for_db("oracle")
index 7c6e8fb02cc9b2ce44577dedcaad3b0b88f532f9..e542c77f43354424e2aa7c13c77c92fc9d46e88e 100644 (file)
@@ -670,7 +670,6 @@ class AsyncAdapt_asyncpg_connection:
     def rollback(self):
         if self._started:
             self.await_(self._transaction.rollback())
-
             self._transaction = None
             self._started = False
 
index d345cdfdfecf6eeab2879526cf63cb64a83551da..70c3908000318ed44d8119cdccd6585f01313934 100644 (file)
@@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables
 from ...testing.provision import drop_all_schema_objects_pre_tables
 from ...testing.provision import drop_db
 from ...testing.provision import log
+from ...testing.provision import prepare_for_drop_tables
 from ...testing.provision import set_default_schema_on_connection
 from ...testing.provision import temp_table_keyword_args
 
@@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng):
                     postgresql.ENUM(name=enum["name"], schema=enum["schema"])
                 )
             )
+
+
+@prepare_for_drop_tables.for_db("postgresql")
+def prepare_for_drop_tables(config, connection):
+    """Ensure there are no locks on the current username/database."""
+
+    result = connection.exec_driver_sql(
+        "select pid, state, wait_event_type, query "
+        # "select pg_terminate_backend(pid), state, wait_event_type "
+        "from pg_stat_activity where "
+        "usename=current_user "
+        "and datname=current_database() and state='idle in transaction' "
+        "and pid != pg_backend_pid()"
+    )
+    rows = result.all()  # noqa
+    assert not rows, (
+        "PostgreSQL may not be able to DROP tables due to "
+        "idle in transaction: %s"
+        % ("; ".join(row._mapping["query"] for row in rows))
+    )
index f26c21e223351123b47b2731e02c305c8efa08e9..a481be27ef0cb0849e5334d4d2a0abb772a35709 100644 (file)
@@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main
 from ...testing.provision import log
 from ...testing.provision import post_configure_engine
 from ...testing.provision import run_reap_dbs
-from ...testing.provision import stop_test_class
+from ...testing.provision import stop_test_class_outside_fixtures
 from ...testing.provision import temp_table_keyword_args
 
 
@@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident):
             os.remove(path)
 
 
-@stop_test_class.for_db("sqlite")
-def stop_test_class(config, db, cls):
+@stop_test_class_outside_fixtures.for_db("sqlite")
+def stop_test_class_outside_fixtures(config, db, cls):
     with db.connect() as conn:
         files = [
             row.file
index 50f00c025d82353b837a17732adda82c03cdd2a6..72d66b7c82567a5c83c5fe54463f185eaa02e3dd 100644 (file)
@@ -2729,14 +2729,16 @@ class Engine(Connectable, log.Identified):
             return self.conn
 
         def __exit__(self, type_, value, traceback):
-
-            if type_ is not None:
-                self.transaction.rollback()
-            else:
-                if self.transaction.is_active:
-                    self.transaction.commit()
-            if not self.close_with_result:
-                self.conn.close()
+            try:
+                if type_ is not None:
+                    if self.transaction.is_active:
+                        self.transaction.rollback()
+                else:
+                    if self.transaction.is_active:
+                        self.transaction.commit()
+            finally:
+                if not self.close_with_result:
+                    self.conn.close()
 
     def begin(self, close_with_result=False):
         """Return a context manager delivering a :class:`_engine.Connection`
index f89be1809f32442466bc333bf3d4e7d77868b876..72d232085e3828277f43ec0295c2e4d2901dd837 100644 (file)
@@ -655,9 +655,12 @@ def create_engine(url, **kwargs):
             c = base.Connection(
                 engine, connection=dbapi_connection, _has_events=False
             )
-            c._execution_options = util.immutabledict()
-            dialect.initialize(c)
-            dialect.do_rollback(c.connection)
+            c._execution_options = util.EMPTY_DICT
+
+            try:
+                dialect.initialize(c)
+            finally:
+                dialect.do_rollback(c.connection)
 
         # previously, the "first_connect" event was used here, which was then
         # scaled back if the "on_connect" handler were present.  now,
index d2f609326ae621e956ae38d8c92bed03fd809937..bfdcdfc7f8447482f8f6684063fe012c400caffa 100644 (file)
@@ -368,12 +368,15 @@ class Engine(_LegacyEngine):
             return self.conn
 
         def __exit__(self, type_, value, traceback):
-            if type_ is not None:
-                self.transaction.rollback()
-            else:
-                if self.transaction.is_active:
-                    self.transaction.commit()
-            self.conn.close()
+            try:
+                if type_ is not None:
+                    if self.transaction.is_active:
+                        self.transaction.rollback()
+                else:
+                    if self.transaction.is_active:
+                        self.transaction.commit()
+            finally:
+                self.conn.close()
 
     def begin(self):
         """Return a :class:`_future.Connection` object with a transaction
index 7c9509e452d71efbab5cb45476876f49124dc8d9..6c3aad037fe82c99d81149781b0ec9906da302b0 100644 (file)
@@ -426,6 +426,7 @@ class _ConnectionRecord(object):
                 rec._checkin_failed(err)
         echo = pool._should_log_debug()
         fairy = _ConnectionFairy(dbapi_connection, rec, echo)
+
         rec.fairy_ref = weakref.ref(
             fairy,
             lambda ref: _finalize_fairy
@@ -609,6 +610,15 @@ def _finalize_fairy(
         assert connection is None
         connection = connection_record.connection
 
+    dont_restore_gced = pool._is_asyncio
+
+    if dont_restore_gced:
+        detach = not connection_record or ref
+        can_manipulate_connection = not ref
+    else:
+        detach = not connection_record
+        can_manipulate_connection = True
+
     if connection is not None:
         if connection_record and echo:
             pool.logger.debug(
@@ -620,13 +630,26 @@ def _finalize_fairy(
                 connection, connection_record, echo
             )
             assert fairy.connection is connection
-            fairy._reset(pool)
+            if can_manipulate_connection:
+                fairy._reset(pool)
+
+            if detach:
+                if connection_record:
+                    fairy._pool = pool
+                    fairy.detach()
+
+                if can_manipulate_connection:
+                    if pool.dispatch.close_detached:
+                        pool.dispatch.close_detached(connection)
+
+                    pool._close_connection(connection)
+                else:
+                    util.warn(
+                        "asyncio connection is being garbage "
+                        "collected without being properly closed: %r"
+                        % connection
+                    )
 
-            # Immediately close detached instances
-            if not connection_record:
-                if pool.dispatch.close_detached:
-                    pool.dispatch.close_detached(connection)
-                pool._close_connection(connection)
         except BaseException as e:
             pool.logger.error(
                 "Exception during reset or similar", exc_info=True
index 191252bfbb22e7a4e752a57d4a984053f4dadfde..9f2d0b857cf2b9daa0a5dce2cc02a2321db4d6be 100644 (file)
@@ -29,8 +29,10 @@ from .assertions import in_  # noqa
 from .assertions import is_  # noqa
 from .assertions import is_false  # noqa
 from .assertions import is_instance_of  # noqa
+from .assertions import is_none  # noqa
 from .assertions import is_not  # noqa
 from .assertions import is_not_  # noqa
+from .assertions import is_not_none  # noqa
 from .assertions import is_true  # noqa
 from .assertions import le_  # noqa
 from .assertions import ne_  # noqa
index 0a2aed9d85f6d9a49c13de3282e475d2fa2b4563..db530a961b9c1f5d360e83b9aba8539c57c005a7 100644 (file)
@@ -232,6 +232,14 @@ def is_false(a, msg=None):
     is_(bool(a), False, msg=msg)
 
 
+def is_none(a, msg=None):
+    is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+    is_not(a, None, msg=msg)
+
+
 def is_(a, b, msg=None):
     """Assert a is b, with repr messaging on failure."""
     assert a is b, msg or "%r is not %r" % (a, b)
index f64153f338e4020303e99319c7f63e47445e8f0b..750671f9f7c60623514b75ad166c30088d8a40c5 100644 (file)
@@ -97,6 +97,10 @@ def get_current_test_name():
     return _fixture_functions.get_current_test_name()
 
 
+def mark_base_test_class():
+    return _fixture_functions.mark_base_test_class()
+
+
 class Config(object):
     def __init__(self, db, db_opts, options, file_config):
         self._set_name(db)
index a4c1f3973baa7a55f7b5c04e529933b51f2e5164..8b334fde20e635a59d314b369750471351f5ef63 100644 (file)
@@ -7,6 +7,7 @@
 
 from __future__ import absolute_import
 
+import collections
 import re
 import warnings
 import weakref
@@ -20,26 +21,29 @@ from .. import pool
 class ConnectionKiller(object):
     def __init__(self):
         self.proxy_refs = weakref.WeakKeyDictionary()
-        self.testing_engines = weakref.WeakKeyDictionary()
-        self.conns = set()
+        self.testing_engines = collections.defaultdict(set)
+        self.dbapi_connections = set()
 
     def add_pool(self, pool):
-        event.listen(pool, "connect", self.connect)
-        event.listen(pool, "checkout", self.checkout)
-        event.listen(pool, "invalidate", self.invalidate)
-
-    def add_engine(self, engine):
-        self.add_pool(engine.pool)
-        self.testing_engines[engine] = True
+        event.listen(pool, "checkout", self._add_conn)
+        event.listen(pool, "checkin", self._remove_conn)
+        event.listen(pool, "close", self._remove_conn)
+        event.listen(pool, "close_detached", self._remove_conn)
+        # note we are keeping "invalidated" here, as those are still
+        # opened connections we would like to roll back
+
+    def _add_conn(self, dbapi_con, con_record, con_proxy):
+        self.dbapi_connections.add(dbapi_con)
+        self.proxy_refs[con_proxy] = True
 
-    def connect(self, dbapi_conn, con_record):
-        self.conns.add((dbapi_conn, con_record))
+    def _remove_conn(self, dbapi_conn, *arg):
+        self.dbapi_connections.discard(dbapi_conn)
 
-    def checkout(self, dbapi_con, con_record, con_proxy):
-        self.proxy_refs[con_proxy] = True
+    def add_engine(self, engine, scope):
+        self.add_pool(engine.pool)
 
-    def invalidate(self, dbapi_con, con_record, exception):
-        self.conns.discard((dbapi_con, con_record))
+        assert scope in ("class", "global", "function", "fixture")
+        self.testing_engines[scope].add(engine)
 
     def _safe(self, fn):
         try:
@@ -54,53 +58,76 @@ class ConnectionKiller(object):
             if rec is not None and rec.is_valid:
                 self._safe(rec.rollback)
 
-    def close_all(self):
+    def checkin_all(self):
+        # run pool.checkin() for all ConnectionFairy instances we have
+        # tracked.
+
         for rec in list(self.proxy_refs):
             if rec is not None and rec.is_valid:
-                self._safe(rec._close)
-
-    def _after_test_ctx(self):
-        # this can cause a deadlock with pg8000 - pg8000 acquires
-        # prepared statement lock inside of rollback() - if async gc
-        # is collecting in finalize_fairy, deadlock.
-        # not sure if this should be for non-cpython only.
-        # note that firebird/fdb definitely needs this though
-        for conn, rec in list(self.conns):
-            if rec.connection is None:
-                # this is a hint that the connection is closed, which
-                # is causing segfaults on mysqlclient due to
-                # https://github.com/PyMySQL/mysqlclient-python/issues/270;
-                # try to work around here
-                continue
-            self._safe(conn.rollback)
-
-    def _stop_test_ctx(self):
-        if config.options.low_connections:
-            self._stop_test_ctx_minimal()
-        else:
-            self._stop_test_ctx_aggressive()
+                self.dbapi_connections.discard(rec.connection)
+                self._safe(rec._checkin)
 
-    def _stop_test_ctx_minimal(self):
-        self.close_all()
+        # for fairy refs that were GCed and could not close the connection,
+        # such as asyncio, roll back those remaining connections
+        for con in self.dbapi_connections:
+            self._safe(con.rollback)
+        self.dbapi_connections.clear()
 
-        self.conns = set()
+    def close_all(self):
+        self.checkin_all()
 
-        for rec in list(self.testing_engines):
-            if rec is not config.db:
-                rec.dispose()
+    def prepare_for_drop_tables(self, connection):
+        # don't do aggressive checks for third party test suites
+        if not config.bootstrapped_as_sqlalchemy:
+            return
 
-    def _stop_test_ctx_aggressive(self):
-        self.close_all()
-        for conn, rec in list(self.conns):
-            self._safe(conn.close)
-            rec.connection = None
+        from . import provision
+
+        provision.prepare_for_drop_tables(connection.engine.url, connection)
+
+    def _drop_testing_engines(self, scope):
+        eng = self.testing_engines[scope]
+        for rec in list(eng):
+            for proxy_ref in list(self.proxy_refs):
+                if proxy_ref is not None and proxy_ref.is_valid:
+                    if (
+                        proxy_ref._pool is not None
+                        and proxy_ref._pool is rec.pool
+                    ):
+                        self._safe(proxy_ref._checkin)
+            rec.dispose()
+        eng.clear()
+
+    def after_test(self):
+        self._drop_testing_engines("function")
+
+    def after_test_outside_fixtures(self, test):
+        # don't do aggressive checks for third party test suites
+        if not config.bootstrapped_as_sqlalchemy:
+            return
+
+        if test.__class__.__leave_connections_for_teardown__:
+            return
 
-        self.conns = set()
-        for rec in list(self.testing_engines):
-            if hasattr(rec, "sync_engine"):
-                rec.sync_engine.dispose()
-            else:
-                rec.dispose()
+        self.checkin_all()
+
+        # on PostgreSQL, this will test for any "idle in transaction"
+        # connections.   useful to identify tests with unusual patterns
+        # that can't be cleaned up correctly.
+        from . import provision
+
+        with config.db.connect() as conn:
+            provision.prepare_for_drop_tables(conn.engine.url, conn)
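The comment above mentions a PostgreSQL-side check for lingering "idle in transaction" sessions. A plausible sketch of what such a check queries (the SQL below is illustrative only; the real logic lives in `provision.prepare_for_drop_tables` and is not shown in this diff):

```python
# Hedged sketch: a session left "idle in transaction" holds locks that would
# block the upcoming DROP TABLE statements.  pg_stat_activity exposes these;
# the column names (pid, state, query) are standard PostgreSQL.
IDLE_IN_TRANSACTION_SQL = """
SELECT pid, state, query
FROM pg_stat_activity
WHERE state = 'idle in transaction'
  AND pid != pg_backend_pid()
"""
```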
+
+    def stop_test_class_inside_fixtures(self):
+        self.checkin_all()
+        self._drop_testing_engines("function")
+        self._drop_testing_engines("class")
+
+    def final_cleanup(self):
+        self.checkin_all()
+        for scope in self.testing_engines:
+            self._drop_testing_engines(scope)
 
     def assert_all_closed(self):
         for rec in self.proxy_refs:
@@ -111,20 +138,6 @@ class ConnectionKiller(object):
 testing_reaper = ConnectionKiller()
 
 
-def drop_all_tables(metadata, bind):
-    testing_reaper.close_all()
-    if hasattr(bind, "close"):
-        bind.close()
-
-    if not config.db.dialect.supports_alter:
-        from . import assertions
-
-        with assertions.expect_warnings("Can't sort tables", assert_=False):
-            metadata.drop_all(bind)
-    else:
-        metadata.drop_all(bind)
-
-
 @decorator
 def assert_conns_closed(fn, *args, **kw):
     try:
@@ -147,7 +160,7 @@ def rollback_open_connections(fn, *args, **kw):
 def close_first(fn, *args, **kw):
     """Decorator that closes all connections before fn execution."""
 
-    testing_reaper.close_all()
+    testing_reaper.checkin_all()
     fn(*args, **kw)
 
 
@@ -157,7 +170,7 @@ def close_open_connections(fn, *args, **kw):
     try:
         fn(*args, **kw)
     finally:
-        testing_reaper.close_all()
+        testing_reaper.checkin_all()
 
 
 def all_dialects(exclude=None):
@@ -239,12 +252,14 @@ def reconnecting_engine(url=None, options=None):
     return engine
 
 
-def testing_engine(url=None, options=None, future=False, asyncio=False):
+def testing_engine(url=None, options=None, future=None, asyncio=False):
     """Produce an engine configured by --options with optional overrides."""
 
     if asyncio:
         from sqlalchemy.ext.asyncio import create_async_engine as create_engine
-    elif future or config.db and config.db._is_future:
+    elif future or (
+        config.db and config.db._is_future and future is not False
+    ):
         from sqlalchemy.future import create_engine
     else:
         from sqlalchemy import create_engine
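The hunk above changes `future` from a default of `False` to `None`, making it a tri-state flag: `None` inherits the "future" mode of the configured engine, while an explicit `True` or `False` always wins. A minimal sketch of that resolution (names here are illustrative, not SQLAlchemy internals):

```python
def resolve_future(future, config_db_is_future):
    """Return True if the 'future' engine API should be used."""
    if future is not None:
        # explicit request in either direction takes precedence
        return bool(future)
    # no explicit request: follow the configured engine
    return config_db_is_future
```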
@@ -252,8 +267,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):
 
     if not options:
         use_reaper = True
+        scope = "function"
     else:
         use_reaper = options.pop("use_reaper", True)
+        scope = options.pop("scope", "function")
 
     url = url or config.db.url
 
@@ -268,16 +285,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False):
         default_opt.update(options)
 
     engine = create_engine(url, **options)
-    if asyncio:
-        engine.sync_engine._has_events = True
-    else:
-        engine._has_events = True  # enable event blocks, helps with profiling
+
+    if scope == "global":
+        if asyncio:
+            engine.sync_engine._has_events = True
+        else:
+            engine._has_events = (
+                True  # enable event blocks, helps with profiling
+            )
 
     if isinstance(engine.pool, pool.QueuePool):
         engine.pool._timeout = 0
-        engine.pool._max_overflow = 5
+        engine.pool._max_overflow = 0
     if use_reaper:
-        testing_reaper.add_engine(engine)
+        testing_reaper.add_engine(engine, scope)
 
     return engine
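`testing_engine()` now passes a `scope` ("function", "fixture", "class", or "global") to the reaper, which the earlier `_drop_testing_engines(scope)` hunk disposes en masse when that scope ends. A minimal model of that scope-keyed registry (assumed names; the real `ConnectionKiller` also checks in proxy refs before disposing):

```python
import collections


class ScopedEngineRegistry:
    """Toy sketch of the reaper's per-scope engine bookkeeping."""

    def __init__(self):
        self.testing_engines = collections.defaultdict(set)

    def add_engine(self, engine, scope):
        assert scope in ("function", "fixture", "class", "global")
        self.testing_engines[scope].add(engine)

    def drop_testing_engines(self, scope):
        # dispose returns pooled connections and closes them
        for engine in list(self.testing_engines[scope]):
            engine.dispose()
        self.testing_engines[scope].clear()
```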
 
index ac4d3d8fa037cf055079f2818d25aedebdd01150..f19b4652adfe972c4ed34c4877680b728a6a4237 100644
@@ -5,6 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+import contextlib
 import re
 import sys
 
@@ -12,12 +13,11 @@ import sqlalchemy as sa
 from . import assertions
 from . import config
 from . import schema
-from .engines import drop_all_tables
-from .engines import testing_engine
 from .entities import BasicEntity
 from .entities import ComparableEntity
 from .entities import ComparableMixin  # noqa
 from .util import adict
+from .util import drop_all_tables_from_metadata
 from .. import event
 from .. import util
 from ..orm import declarative_base
@@ -25,10 +25,8 @@ from ..orm import registry
 from ..orm.decl_api import DeclarativeMeta
 from ..schema import sort_tables_and_constraints
 
-# whether or not we use unittest changes things dramatically,
-# as far as how pytest collection works.
-
 
+@config.mark_base_test_class()
 class TestBase(object):
     # A sequence of database names to always run, regardless of the
     # constraints below.
@@ -48,81 +46,114 @@ class TestBase(object):
     # skipped.
     __skip_if__ = None
 
+    # if True, the testing reaper will not attempt to touch connection
+    # state after a test is completed and before the outer teardown
+    # starts
+    __leave_connections_for_teardown__ = False
+
     def assert_(self, val, msg=None):
         assert val, msg
 
-    # apparently a handful of tests are doing this....OK
-    def setup(self):
-        if hasattr(self, "setUp"):
-            self.setUp()
-
-    def teardown(self):
-        if hasattr(self, "tearDown"):
-            self.tearDown()
-
     @config.fixture()
     def connection(self):
-        eng = getattr(self, "bind", config.db)
+        global _connection_fixture_connection
+
+        eng = getattr(self, "bind", None) or config.db
 
         conn = eng.connect()
         trans = conn.begin()
-        try:
-            yield conn
-        finally:
-            if trans.is_active:
-                trans.rollback()
-            conn.close()
+
+        _connection_fixture_connection = conn
+        yield conn
+
+        _connection_fixture_connection = None
+
+        if trans.is_active:
+            trans.rollback()
+        # trans would not be active here if the test is still using
+        # the legacy @provide_metadata decorator, as that decorator
+        # runs a close-all-connections step.
+        conn.close()
 
     @config.fixture()
-    def future_connection(self):
+    def future_connection(self, future_engine, connection):
+        # integrate the future_engine and connection fixtures so
+        # that users of the "connection" fixture will get the
+        # "future" connection
+        yield connection
 
-        eng = testing_engine(future=True)
-        conn = eng.connect()
-        trans = conn.begin()
-        try:
-            yield conn
-        finally:
-            if trans.is_active:
-                trans.rollback()
-            conn.close()
+    @config.fixture()
+    def future_engine(self):
+        eng = getattr(self, "bind", None) or config.db
+        with _push_future_engine(eng):
+            yield
+
+    @config.fixture()
+    def testing_engine(self):
+        from . import engines
+
+        def gen_testing_engine(
+            url=None, options=None, future=False, asyncio=False
+        ):
+            if options is None:
+                options = {}
+            options["scope"] = "fixture"
+            return engines.testing_engine(
+                url=url, options=options, future=future, asyncio=asyncio
+            )
+
+        yield gen_testing_engine
+
+        engines.testing_reaper._drop_testing_engines("fixture")
 
     @config.fixture()
-    def metadata(self):
+    def metadata(self, request):
         """Provide bound MetaData for a single test, dropping afterwards."""
 
-        from . import engines
         from ..sql import schema
 
         metadata = schema.MetaData()
-        try:
-            yield metadata
-        finally:
-            engines.drop_all_tables(metadata, config.db)
+        request.instance.metadata = metadata
+        yield metadata
+        del request.instance.metadata
 
+        if (
+            _connection_fixture_connection
+            and _connection_fixture_connection.in_transaction()
+        ):
+            trans = _connection_fixture_connection.get_transaction()
+            trans.rollback()
+            with _connection_fixture_connection.begin():
+                drop_all_tables_from_metadata(
+                    metadata, _connection_fixture_connection
+                )
+        else:
+            drop_all_tables_from_metadata(metadata, config.db)
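The teardown branch above encodes an ordering constraint: if the test's "connection" fixture still holds an open transaction, it must be rolled back before DDL runs, or the DROP statements would block on locks that transaction holds. A sketch of that flow with hypothetical stand-in names:

```python
def teardown_metadata(metadata, fixture_connection, default_engine, drop_all):
    """Drop all tables, first releasing any open fixture transaction."""
    if fixture_connection is not None and fixture_connection.in_transaction():
        # release locks held by the test's transaction, then drop on the
        # same connection inside a fresh transaction
        fixture_connection.get_transaction().rollback()
        with fixture_connection.begin():
            drop_all(metadata, fixture_connection)
    else:
        # no fixture connection in play; drop on the default engine
        drop_all(metadata, default_engine)
```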
 
-class FutureEngineMixin(object):
-    @classmethod
-    def setup_class(cls):
 
-        from ..future.engine import Engine
-        from sqlalchemy import testing
+_connection_fixture_connection = None
 
-        facade = Engine._future_facade(config.db)
-        config._current.push_engine(facade, testing)
 
-        super_ = super(FutureEngineMixin, cls)
-        if hasattr(super_, "setup_class"):
-            super_.setup_class()
+@contextlib.contextmanager
+def _push_future_engine(engine):
 
-    @classmethod
-    def teardown_class(cls):
-        super_ = super(FutureEngineMixin, cls)
-        if hasattr(super_, "teardown_class"):
-            super_.teardown_class()
+    from ..future.engine import Engine
+    from sqlalchemy import testing
+
+    facade = Engine._future_facade(engine)
+    config._current.push_engine(facade, testing)
+
+    yield facade
 
-        from sqlalchemy import testing
+    config._current.pop(testing)
 
-        config._current.pop(testing)
+
+class FutureEngineMixin(object):
+    @config.fixture(autouse=True, scope="class")
+    def _push_future_engine(self):
+        eng = getattr(self, "bind", None) or config.db
+        with _push_future_engine(eng):
+            yield
 
 
 class TablesTest(TestBase):
@@ -151,18 +182,32 @@ class TablesTest(TestBase):
     other = None
     sequences = None
 
-    @property
-    def tables_test_metadata(self):
-        return self._tables_metadata
-
-    @classmethod
-    def setup_class(cls):
+    @config.fixture(autouse=True, scope="class")
+    def _setup_tables_test_class(self):
+        cls = self.__class__
         cls._init_class()
 
         cls._setup_once_tables()
 
         cls._setup_once_inserts()
 
+        yield
+
+        cls._teardown_once_metadata_bind()
+
+    @config.fixture(autouse=True, scope="function")
+    def _setup_tables_test_instance(self):
+        self._setup_each_tables()
+        self._setup_each_inserts()
+
+        yield
+
+        self._teardown_each_tables()
+
+    @property
+    def tables_test_metadata(self):
+        return self._tables_metadata
+
     @classmethod
     def _init_class(cls):
         if cls.run_define_tables == "each":
@@ -213,10 +258,10 @@ class TablesTest(TestBase):
         if self.run_define_tables == "each":
             self.tables.clear()
             if self.run_create_tables == "each":
-                drop_all_tables(self._tables_metadata, self.bind)
+                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
             self._tables_metadata.clear()
         elif self.run_create_tables == "each":
-            drop_all_tables(self._tables_metadata, self.bind)
+            drop_all_tables_from_metadata(self._tables_metadata, self.bind)
 
         # no need to run deletes if tables are recreated on setup
         if (
@@ -242,17 +287,10 @@ class TablesTest(TestBase):
                             file=sys.stderr,
                         )
 
-    def setup(self):
-        self._setup_each_tables()
-        self._setup_each_inserts()
-
-    def teardown(self):
-        self._teardown_each_tables()
-
     @classmethod
     def _teardown_once_metadata_bind(cls):
         if cls.run_create_tables:
-            drop_all_tables(cls._tables_metadata, cls.bind)
+            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
 
         if cls.run_dispose_bind == "once":
             cls.dispose_bind(cls.bind)
@@ -262,10 +300,6 @@ class TablesTest(TestBase):
         if cls.run_setup_bind is not None:
             cls.bind = None
 
-    @classmethod
-    def teardown_class(cls):
-        cls._teardown_once_metadata_bind()
-
     @classmethod
     def setup_bind(cls):
         return config.db
@@ -332,38 +366,47 @@ class RemovesEvents(object):
         self._event_fns.add((target, name, fn))
         event.listen(target, name, fn, **kw)
 
-    def teardown(self):
+    @config.fixture(autouse=True, scope="function")
+    def _remove_events(self):
+        yield
         for key in self._event_fns:
             event.remove(*key)
-        super_ = super(RemovesEvents, self)
-        if hasattr(super_, "teardown"):
-            super_.teardown()
-
-
-class _ORMTest(object):
-    @classmethod
-    def teardown_class(cls):
-        sa.orm.session.close_all_sessions()
-        sa.orm.clear_mappers()
 
 
-def create_session(**kw):
-    kw.setdefault("autoflush", False)
-    kw.setdefault("expire_on_commit", False)
-    return sa.orm.Session(config.db, **kw)
+_fixture_sessions = set()
 
 
 def fixture_session(**kw):
     kw.setdefault("autoflush", True)
     kw.setdefault("expire_on_commit", True)
-    return sa.orm.Session(config.db, **kw)
+    sess = sa.orm.Session(config.db, **kw)
+    _fixture_sessions.add(sess)
+    return sess
+
+
+def _close_all_sessions():
+    # will close all still-referenced sessions
+    sa.orm.session.close_all_sessions()
+    _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+    _close_all_sessions()
+    sa.orm.clear_mappers()
 
 
-class ORMTest(_ORMTest, TestBase):
+def after_test():
+    if _fixture_sessions:
+        _close_all_sessions()
+
+
+class ORMTest(TestBase):
     pass
 
 
-class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
     # 'once', 'each', None
     run_setup_classes = "once"
 
@@ -372,8 +415,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
 
     classes = None
 
-    @classmethod
-    def setup_class(cls):
+    @config.fixture(autouse=True, scope="class")
+    def _setup_tables_test_class(self):
+        cls = self.__class__
         cls._init_class()
 
         if cls.classes is None:
@@ -384,18 +428,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
         cls._setup_once_mappers()
         cls._setup_once_inserts()
 
-    @classmethod
-    def teardown_class(cls):
+        yield
+
         cls._teardown_once_class()
         cls._teardown_once_metadata_bind()
 
-    def setup(self):
+    @config.fixture(autouse=True, scope="function")
+    def _setup_tables_test_instance(self):
         self._setup_each_tables()
         self._setup_each_classes()
         self._setup_each_mappers()
         self._setup_each_inserts()
 
-    def teardown(self):
+        yield
+
         sa.orm.session.close_all_sessions()
         self._teardown_each_mappers()
         self._teardown_each_classes()
@@ -404,7 +450,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
     @classmethod
     def _teardown_once_class(cls):
         cls.classes.clear()
-        _ORMTest.teardown_class()
 
     @classmethod
     def _setup_once_classes(cls):
@@ -440,6 +485,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults):
         """
         cls_registry = cls.classes
 
+        assert cls_registry is not None
+
         class FindFixture(type):
             def __init__(cls, classname, bases, dict_):
                 cls_registry[classname] = cls
index a95c947e2002842d88385e2482d17e805829e64a..1f568dfc8f52638d0e433b8c8e3a756a4d450ffd 100644
@@ -40,6 +40,11 @@ def load_file_as_module(name):
 
 if to_bootstrap == "pytest":
     sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
+    sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
+    if sys.version_info < (3, 0):
+        sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
+            "reinvent_fixtures_py2k"
+        )
     sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
 else:
     raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
index 3594cd276dab376315e04a306bbd992c0071c547..7851fbb3ece67429cb8d444fdcf799c6f50083da 100644
@@ -21,6 +21,9 @@ import logging
 import re
 import sys
 
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
 
 log = logging.getLogger("sqlalchemy.testing.plugin_base")
 
@@ -381,7 +384,7 @@ def _init_symbols(options, file_config):
 
 @post
 def _set_disable_asyncio(opt, file_config):
-    if opt.disable_asyncio:
+    if opt.disable_asyncio or not py3k:
         from sqlalchemy.testing import asyncio
 
         asyncio.ENABLE_ASYNCIO = False
@@ -458,6 +461,8 @@ def _setup_requirements(argument):
 
     config.requirements = testing.requires = req_cls()
 
+    config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
 
 @post
 def _prep_testing_database(options, file_config):
@@ -566,17 +571,22 @@ def generate_sub_tests(cls, module):
         yield cls
 
 
-def start_test_class(cls):
+def start_test_class_outside_fixtures(cls):
     _do_skips(cls)
     _setup_engine(cls)
 
 
 def stop_test_class(cls):
-    # from sqlalchemy import inspect
-    # assert not inspect(testing.db).get_table_names()
+    # close sessions, immediate connections, etc.
+    fixtures.stop_test_class_inside_fixtures(cls)
+
+    # close outstanding connection pool connections, dispose of
+    # additional engines
+    engines.testing_reaper.stop_test_class_inside_fixtures()
 
-    provision.stop_test_class(config, config.db, cls)
-    engines.testing_reaper._stop_test_ctx()
+
+def stop_test_class_outside_fixtures(cls):
+    provision.stop_test_class_outside_fixtures(config, config.db, cls)
     try:
         if not options.low_connections:
             assertions.global_cleanup_assertions()
@@ -590,14 +600,16 @@ def _restore_engine():
 
 
 def final_process_cleanup():
-    engines.testing_reaper._stop_test_ctx_aggressive()
+    engines.testing_reaper.final_cleanup()
     assertions.global_cleanup_assertions()
     _restore_engine()
 
 
 def _setup_engine(cls):
     if getattr(cls, "__engine_options__", None):
-        eng = engines.testing_engine(options=cls.__engine_options__)
+        opts = dict(cls.__engine_options__)
+        opts["scope"] = "class"
+        eng = engines.testing_engine(options=opts)
         config._current.push_engine(eng, testing)
 
 
@@ -614,7 +626,12 @@ def before_test(test, test_module_name, test_class, test_name):
 
 
 def after_test(test):
-    engines.testing_reaper._after_test_ctx()
+    fixtures.after_test()
+    engines.testing_reaper.after_test()
+
+
+def after_test_fixtures(test):
+    engines.testing_reaper.after_test_outside_fixtures(test)
 
 
 def _possible_configs_for_cls(cls, reasons=None, sparse=False):
@@ -748,6 +765,10 @@ class FixtureFunctions(ABC):
     def get_current_test_name(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def mark_base_test_class(self):
+        raise NotImplementedError()
+
 
 _fixture_fn_class = None
 
index 46468a07dcb70ab55f4e000fdd2b1c5acbfe8413..4eaaecebb1c785e45f478404109ba6455c93948c 100644
@@ -17,6 +17,7 @@ import sys
 
 import pytest
 
+
 try:
     import typing
 except ImportError:
@@ -33,6 +34,14 @@ except ImportError:
     has_xdist = False
 
 
+py2k = sys.version_info < (3, 0)
+if py2k:
+    try:
+        import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+    except ImportError:
+        from . import reinvent_fixtures_py2k
+
+
 def pytest_addoption(parser):
     group = parser.getgroup("sqlalchemy")
 
@@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items):
         else:
             newitems.append(item)
 
+    if py2k:
+        for item in newitems:
+            reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)
+
     # seems like the functions attached to a test class aren't sorted already?
     # is that true and why's that? (when using unittest, they're sorted)
     items[:] = sorted(
@@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items):
 
 
 def pytest_pycollect_makeitem(collector, name, obj):
-
     if inspect.isclass(obj) and plugin_base.want_class(name, obj):
         from sqlalchemy.testing import config
 
@@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj):
             obj = _apply_maybe_async(obj)
 
         ctor = getattr(pytest.Class, "from_parent", pytest.Class)
-
         return [
             ctor(name=parametrize_cls.__name__, parent=collector)
             for parametrize_cls in _parametrize_cls(collector.module, obj)
@@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn):
 def _apply_maybe_async(obj, recurse=True):
     from sqlalchemy.testing import asyncio
 
-    setup_names = {"setup", "setup_class", "teardown", "teardown_class"}
     for name, value in vars(obj).items():
         if (
             (callable(value) or isinstance(value, classmethod))
             and not getattr(value, "_maybe_async_applied", False)
-            and (name.startswith("test_") or name in setup_names)
+            and (name.startswith("test_"))
             and not _is_wrapped_coroutine_function(value)
         ):
             is_classmethod = False
@@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True):
     return obj
 
 
-_current_class = None
-
-
 def _parametrize_cls(module, cls):
     """implement a class-based version of pytest parametrize."""
 
@@ -355,63 +362,153 @@ def _parametrize_cls(module, cls):
     return classes
 
 
+_current_class = None
+
+
 def pytest_runtest_setup(item):
     from sqlalchemy.testing import asyncio
 
-    # here we seem to get called only based on what we collected
-    # in pytest_collection_modifyitems.   So to do class-based stuff
-    # we have to tear that out.
-    global _current_class
-
     if not isinstance(item, pytest.Function):
         return
 
-    # ... so we're doing a little dance here to figure it out...
+    # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
+    # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
+    # for the whole class and has to run things across all current databases,
+    # so we run this outside of the pytest fixture system altogether, and
+    # ensure the asyncio greenlet is set up if any engines are async
+
+    global _current_class
+
     if _current_class is None:
-        asyncio._maybe_async(class_setup, item.parent.parent)
+        asyncio._maybe_async_provisioning(
+            plugin_base.start_test_class_outside_fixtures,
+            item.parent.parent.cls,
+        )
         _current_class = item.parent.parent
 
-        # this is needed for the class-level, to ensure that the
-        # teardown runs after the class is completed with its own
-        # class-level teardown...
         def finalize():
             global _current_class
-            asyncio._maybe_async(class_teardown, item.parent.parent)
             _current_class = None
 
+            asyncio._maybe_async_provisioning(
+                plugin_base.stop_test_class_outside_fixtures,
+                item.parent.parent.cls,
+            )
+
         item.parent.parent.addfinalizer(finalize)
 
-    asyncio._maybe_async(test_setup, item)
 
+def pytest_runtest_call(item):
+    # runs inside of pytest function fixture scope
+    # before test function runs
 
-def pytest_runtest_teardown(item):
     from sqlalchemy.testing import asyncio
 
-    # ...but this works better as the hook here rather than
-    # using a finalizer, as the finalizer seems to get in the way
-    # of the test reporting failures correctly (you get a bunch of
-    # pytest assertion stuff instead)
-    asyncio._maybe_async(test_teardown, item)
+    asyncio._maybe_async(
+        plugin_base.before_test,
+        item,
+        item.parent.module.__name__,
+        item.parent.cls,
+        item.name,
+    )
 
 
-def test_setup(item):
-    plugin_base.before_test(
-        item, item.parent.module.__name__, item.parent.cls, item.name
-    )
+def pytest_runtest_teardown(item, nextitem):
+    # runs inside of pytest function fixture scope
+    # after test function runs
 
+    from sqlalchemy.testing import asyncio
 
-def test_teardown(item):
-    plugin_base.after_test(item)
+    asyncio._maybe_async(plugin_base.after_test, item)
 
 
-def class_setup(item):
+@pytest.fixture(scope="class")
+def setup_class_methods(request):
     from sqlalchemy.testing import asyncio
 
-    asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls)
+    cls = request.cls
+
+    if hasattr(cls, "setup_test_class"):
+        asyncio._maybe_async(cls.setup_test_class)
+
+    if py2k:
+        reinvent_fixtures_py2k.run_class_fixture_setup(request)
+
+    yield
+
+    if py2k:
+        reinvent_fixtures_py2k.run_class_fixture_teardown(request)
 
+    if hasattr(cls, "teardown_test_class"):
+        asyncio._maybe_async(cls.teardown_test_class)
 
-def class_teardown(item):
-    plugin_base.stop_test_class(item.cls)
+    asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
+@pytest.fixture(scope="function")
+def setup_test_methods(request):
+    from sqlalchemy.testing import asyncio
+
+    # called for each test
+
+    self = request.instance
+
+    # 1. run outer xdist-style setup
+    if hasattr(self, "setup_test"):
+        asyncio._maybe_async(self.setup_test)
+
+    # alembic test suite is using setUp and tearDown
+    # xdist methods; support these in the test suite
+    # for the near term
+    if hasattr(self, "setUp"):
+        asyncio._maybe_async(self.setUp)
+
+    # 2. run homegrown function level "autouse" fixtures under py2k
+    if py2k:
+        reinvent_fixtures_py2k.run_fn_fixture_setup(request)
+
+    # inside the yield:
+
+    # 3. function level "autouse" fixtures under py3k (examples: TablesTest
+    #    define tables / data, MappedTest define tables / mappers / data)
+
+    # 4. function level fixtures defined on test functions themselves,
+    #    e.g. "connection", "metadata" run next
+
+    # 5. pytest hook pytest_runtest_call then runs
+
+    # 6. test itself runs
+
+    yield
+
+    # yield finishes:
+
+    # 7. pytest hook pytest_runtest_teardown runs; at this stage, fixture
+    #    sessions are closed, provisioning.stop_test_class() runs, and
+    #    engines.testing_reaper ensures all connection pool connections
+    #    are returned; engines created by testing_engine that aren't the
+    #    config engine are disposed
+
+    # 8. function level fixtures defined on test functions
+    #    themselves, e.g. "connection" rolls back the transaction, "metadata"
+    #    emits drop all
+
+    # 9. function level "autouse" fixtures under py3k (examples: TablesTest /
+    #    MappedTest delete table data, possibly drop tables and clear mappers
+    #    depending on the flags defined by the test class)
+
+    # 10. run homegrown function-level "autouse" fixtures under py2k
+    if py2k:
+        reinvent_fixtures_py2k.run_fn_fixture_teardown(request)
+
+    asyncio._maybe_async(plugin_base.after_test_fixtures, self)
+
+    # 11. run outer xdist-style teardown
+    if hasattr(self, "tearDown"):
+        asyncio._maybe_async(self.tearDown)
+
+    if hasattr(self, "teardown_test"):
+        asyncio._maybe_async(self.teardown_test)
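The mechanics of a yield fixture like `setup_test_methods` can be reduced to a toy model: pytest runs the generator to its `yield` before the test and resumes it afterwards, so steps 1-2 land before the test and steps 10-11 after. A hedged sketch (simplified; real ordering is driven by pytest's fixture machinery):

```python
def run_with_fixture(instance, test_fn):
    """Drive a yield-style setup/teardown wrapper around a test function."""

    def setup_test_methods():
        # steps 1-2: outer xdist-style setup
        if hasattr(instance, "setup_test"):
            instance.setup_test()
        yield
        # steps 10-11: outer xdist-style teardown
        if hasattr(instance, "teardown_test"):
            instance.teardown_test()

    gen = setup_test_methods()
    next(gen)                # run setup, stop at the yield
    try:
        test_fn()            # the test itself runs "inside" the yield
    finally:
        try:
            next(gen)        # resume: run teardown past the yield
        except StopIteration:
            pass
```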
 
 
 def getargspec(fn):
@@ -461,6 +558,8 @@ def %(name)s(%(args)s):
             # for the wrapped function
             decorated.__module__ = fn.__module__
             decorated.__name__ = fn.__name__
+            if hasattr(fn, "pytestmark"):
+                decorated.pytestmark = fn.pytestmark
             return decorated
 
     return decorate
@@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
     def skip_test_exception(self, *arg, **kw):
         return pytest.skip.Exception(*arg, **kw)
 
+    def mark_base_test_class(self):
+        return pytest.mark.usefixtures(
+            "setup_class_methods", "setup_test_methods"
+        )
+
     _combination_id_fns = {
         "i": lambda obj: obj,
         "r": repr,
@@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
                 fn = asyncio._maybe_async_wrapper(fn)
             # other wrappers may be added here
 
-            # now apply FixtureFunctionMarker
-            fn = fixture(fn)
+            if py2k and "autouse" in kw:
+                # py2k workaround for too-slow collection of autouse fixtures
+                # in pytest 4.6.11.  See notes in reinvent_fixtures_py2k for
+                # rationale.
+
+                # comment this condition out in order to disable the
+                # py2k workaround entirely.
+                reinvent_fixtures_py2k.add_fixture(fn, fixture)
+            else:
+                # now apply FixtureFunctionMarker
+                fn = fixture(fn)
+
             return fn
 
         if fn:
diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
new file mode 100644
index 0000000..36b6841
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
@@ -0,0 +1,112 @@
+"""
+invent a quick version of pytest autouse fixtures to work around pytest's
+unacceptably slow collection / high memory use in pytest 4.6.11, which is the
+highest version that works in py2k.
+
+by "too slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so, and memory use seems to be very
+high as well.   for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures is limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
+_py2k_fixture_fn_names = collections.defaultdict(set)
+_py2k_class_fixtures = collections.defaultdict(
+    lambda: collections.defaultdict(set)
+)
+_py2k_function_fixtures = collections.defaultdict(
+    lambda: collections.defaultdict(set)
+)
+
+_py2k_cls_fixture_stack = []
+_py2k_fn_fixture_stack = []
+
+
+def add_fixture(fn, fixture):
+    assert fixture.scope in ("class", "function")
+    _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
+def scan_for_fixtures_to_use_for_class(item):
+    test_class = item.parent.parent.obj
+
+    for name in _py2k_fixture_fn_names:
+        for fixture_fn, scope in _py2k_fixture_fn_names[name]:
+            meth = getattr(test_class, name, None)
+            if meth and meth.im_func is fixture_fn:
+                for sup in test_class.__mro__:
+                    if name in sup.__dict__:
+                        if scope == "class":
+                            _py2k_class_fixtures[test_class][sup].add(meth)
+                        elif scope == "function":
+                            _py2k_function_fixtures[test_class][sup].add(meth)
+                        break
+                break
+
+
+def run_class_fixture_setup(request):
+
+    cls = request.cls
+    self = cls.__new__(cls)
+
+    fixtures_for_this_class = _py2k_class_fixtures.get(cls)
+
+    if fixtures_for_this_class:
+        for sup_ in cls.__mro__:
+            for fn in fixtures_for_this_class.get(sup_, ()):
+                iter_ = fn(self)
+                next(iter_)
+
+                _py2k_cls_fixture_stack.append(iter_)
+
+
+def run_class_fixture_teardown(request):
+    while _py2k_cls_fixture_stack:
+        iter_ = _py2k_cls_fixture_stack.pop(-1)
+        try:
+            next(iter_)
+        except StopIteration:
+            pass
+
+
+def run_fn_fixture_setup(request):
+    cls = request.cls
+    self = request.instance
+
+    fixtures_for_this_class = _py2k_function_fixtures.get(cls)
+
+    if fixtures_for_this_class:
+        for sup_ in reversed(cls.__mro__):
+            for fn in fixtures_for_this_class.get(sup_, ()):
+                iter_ = fn(self)
+                next(iter_)
+
+                _py2k_fn_fixture_stack.append(iter_)
+
+
+def run_fn_fixture_teardown(request):
+    while _py2k_fn_fixture_stack:
+        iter_ = _py2k_fn_fixture_stack.pop(-1)
+        try:
+            next(iter_)
+        except StopIteration:
+            pass
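The generator-fixture protocol that the shims above drive by hand — setup is everything up to the first yield, teardown is everything after it, unwound in LIFO order — can be sketched outside of pytest as follows. The test-class and fixture names here are illustrative, not part of the patch.

```python
# Minimal sketch of the protocol run_fn_fixture_setup/teardown implement:
# advance each generator to its yield before the test, then resume each
# one afterward (newest first) until StopIteration.
events = []


class SomeTest(object):
    def connection_fixture(self):
        events.append("setup")      # runs on next(iter_) in *_fixture_setup
        yield
        events.append("teardown")   # runs on next(iter_) in *_fixture_teardown


stack = []

# emulate run_fn_fixture_setup
self_ = SomeTest()
iter_ = self_.connection_fixture()
next(iter_)
stack.append(iter_)

events.append("test body")

# emulate run_fn_fixture_teardown: resume in LIFO order until exhausted
while stack:
    it = stack.pop(-1)
    try:
        next(it)
    except StopIteration:
        pass

print(events)
```

Resuming the generator executes the post-yield code first and only then raises StopIteration, which is why the teardown side needs no explicit call beyond `next()`.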
index 4ee0567f22402220e7781d36b1b470a96b62a232..2fade1c32d3524748a8d19ad235853d4882b6c0f 100644 (file)
@@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident):
         db_url = follower_url_from_main(db_url, follower_ident)
     db_opts = {}
     update_db_opts(db_url, db_opts)
+    db_opts["scope"] = "global"
     eng = engines.testing_engine(db_url, db_opts)
     post_configure_engine(db_url, eng, follower_ident)
     eng.connect().close()
@@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng):
 
     if config.requirements.schemas.enabled_for_config(cfg):
         util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
+        util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
 
     drop_all_schema_objects_post_tables(cfg, eng)
 
@@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts):
 def post_configure_engine(url, engine, follower_ident):
     """Perform extra steps after configuring an engine for testing.
 
-    (For the internal dialects, currently only used by sqlite.)
+    (For the internal dialects, currently only used by sqlite and oracle.)
     """
     pass
 
@@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng):
 
 
 @register.init
-def stop_test_class(config, db, testcls):
+def prepare_for_drop_tables(config, connection):
+    pass
+
+
+@register.init
+def stop_test_class_outside_fixtures(config, db, testcls):
     pass
 
 
index 6c3c1005ab54b196e5f92199a43e42103d498d2d..de157d028d8dda1fe40ceba8276be8dbaedf934f 100644 (file)
@@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
             from sqlalchemy import pool
 
             return engines.testing_engine(
-                options=dict(poolclass=pool.StaticPool)
+                options=dict(poolclass=pool.StaticPool, scope="class"),
             )
         else:
             return config.db
index e0fdbe47a799b2322b02163b0ca47720129daee1..e8dd6cf2c9e08bd517eee84aebadf28274b6d5fe 100644 (file)
@@ -261,10 +261,6 @@ class ServerSideCursorsTest(
             )
         return self.engine
 
-    def tearDown(self):
-        engines.testing_reaper.close_all()
-        self.engine.dispose()
-
     @testing.combinations(
         ("global_string", True, "select 1", True),
         ("global_text", True, text("select 1"), True),
@@ -309,24 +305,22 @@ class ServerSideCursorsTest(
     def test_conn_option(self):
         engine = self._fixture(False)
 
-        # should be enabled for this one
-        result = (
-            engine.connect()
-            .execution_options(stream_results=True)
-            .exec_driver_sql("select 1")
-        )
-        assert self._is_server_side(result.cursor)
+        with engine.connect() as conn:
+            # should be enabled for this one
+            result = conn.execution_options(
+                stream_results=True
+            ).exec_driver_sql("select 1")
+            assert self._is_server_side(result.cursor)
 
     def test_stmt_enabled_conn_option_disabled(self):
         engine = self._fixture(False)
 
         s = select(1).execution_options(stream_results=True)
 
-        # not this one
-        result = (
-            engine.connect().execution_options(stream_results=False).execute(s)
-        )
-        assert not self._is_server_side(result.cursor)
+        with engine.connect() as conn:
+            # not this one
+            result = conn.execution_options(stream_results=False).execute(s)
+            assert not self._is_server_side(result.cursor)
 
     def test_aliases_and_ss(self):
         engine = self._fixture(False)
@@ -344,8 +338,7 @@ class ServerSideCursorsTest(
             assert not self._is_server_side(result.cursor)
             result.close()
 
-    @testing.provide_metadata
-    def test_roundtrip_fetchall(self):
+    def test_roundtrip_fetchall(self, metadata):
         md = self.metadata
 
         engine = self._fixture(True)
@@ -385,8 +378,7 @@ class ServerSideCursorsTest(
                 0,
             )
 
-    @testing.provide_metadata
-    def test_roundtrip_fetchmany(self):
+    def test_roundtrip_fetchmany(self, metadata):
         md = self.metadata
 
         engine = self._fixture(True)
index 3a5e02c32bbe7595c7f29bf857102ff599cef611..ebcceaae7c06faba434b7baf69ae391b8e2b4ce1 100644 (file)
@@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
     __backend__ = True
 
     @testing.fixture
-    def do_numeric_test(self, metadata):
+    def do_numeric_test(self, metadata, connection):
         @testing.emits_warning(
             r".*does \*not\* support Decimal objects natively"
         )
         def run(type_, input_, output, filter_=None, check_scale=False):
             t = Table("t", metadata, Column("x", type_))
-            t.create(testing.db)
-            with config.db.begin() as conn:
-                conn.execute(t.insert(), [{"x": x} for x in input_])
-
-                result = {row[0] for row in conn.execute(t.select())}
-                output = set(output)
-                if filter_:
-                    result = set(filter_(x) for x in result)
-                    output = set(filter_(x) for x in output)
-                eq_(result, output)
-                if check_scale:
-                    eq_([str(x) for x in result], [str(x) for x in output])
+            t.create(connection)
+            connection.execute(t.insert(), [{"x": x} for x in input_])
+
+            result = {row[0] for row in connection.execute(t.select())}
+            output = set(output)
+            if filter_:
+                result = set(filter_(x) for x in result)
+                output = set(filter_(x) for x in output)
+            eq_(result, output)
+            if check_scale:
+                eq_([str(x) for x in result], [str(x) for x in output])
 
         return run
 
@@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
                 },
             )
 
-    def test_eval_none_flag_orm(self):
+    def test_eval_none_flag_orm(self, connection):
 
         Base = declarative_base()
 
         class Data(Base):
             __table__ = self.tables.data_table
 
-        s = Session(testing.db)
+        with Session(connection) as s:
+            d1 = Data(name="d1", data=None, nulldata=None)
+            s.add(d1)
+            s.commit()
 
-        d1 = Data(name="d1", data=None, nulldata=None)
-        s.add(d1)
-        s.commit()
-
-        s.bulk_insert_mappings(
-            Data, [{"name": "d2", "data": None, "nulldata": None}]
-        )
-        eq_(
-            s.query(
-                cast(self.tables.data_table.c.data, String()),
-                cast(self.tables.data_table.c.nulldata, String),
+            s.bulk_insert_mappings(
+                Data, [{"name": "d2", "data": None, "nulldata": None}]
             )
-            .filter(self.tables.data_table.c.name == "d1")
-            .first(),
-            ("null", None),
-        )
-        eq_(
-            s.query(
-                cast(self.tables.data_table.c.data, String()),
-                cast(self.tables.data_table.c.nulldata, String),
+            eq_(
+                s.query(
+                    cast(self.tables.data_table.c.data, String()),
+                    cast(self.tables.data_table.c.nulldata, String),
+                )
+                .filter(self.tables.data_table.c.name == "d1")
+                .first(),
+                ("null", None),
+            )
+            eq_(
+                s.query(
+                    cast(self.tables.data_table.c.data, String()),
+                    cast(self.tables.data_table.c.nulldata, String),
+                )
+                .filter(self.tables.data_table.c.name == "d2")
+                .first(),
+                ("null", None),
             )
-            .filter(self.tables.data_table.c.name == "d2")
-            .first(),
-            ("null", None),
-        )
 
 
 class JSONLegacyStringCastIndexTest(
index eb9fcd1cd1a31bee389e1692b1405332408a092e..01185c2841d88542e3b33dbfaafbc60145383f26 100644 (file)
@@ -14,6 +14,7 @@ import types
 from . import config
 from . import mock
 from .. import inspect
+from ..engine import Connection
 from ..schema import Column
 from ..schema import DropConstraint
 from ..schema import DropTable
@@ -207,11 +208,13 @@ def fail(msg):
 
 @decorator
 def provide_metadata(fn, *args, **kw):
-    """Provide bound MetaData for a single test, dropping afterwards."""
+    """Provide bound MetaData for a single test, dropping afterwards.
 
-    # import cycle that only occurs with py2k's import resolver
-    # in py3k this can be moved top level.
-    from . import engines
+    Legacy; use the "metadata" pytest fixture.
+
+    """
+
+    from . import fixtures
 
     metadata = schema.MetaData()
     self = args[0]
@@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw):
     try:
         return fn(*args, **kw)
     finally:
-        engines.drop_all_tables(metadata, config.db)
+        # close out some things that get in the way of dropping tables.
+        # when using the "metadata" fixture, there is a set ordering of
+        # steps that ensures cleanup happens in order; however, the
+        # simple "decorator" nature of this legacy function means
+        # we have to hardcode some of that cleanup ahead of time.
+
+        # close ORM sessions
+        fixtures._close_all_sessions()
+
+        # integrate with the "connection" fixture as there are many
+        # tests where it is used along with provide_metadata
+        if fixtures._connection_fixture_connection:
+            # TODO: this warning can be used to find all the places
+            # this is used with connection fixture
+            # warn("mixing legacy provide metadata with connection fixture")
+            drop_all_tables_from_metadata(
+                metadata, fixtures._connection_fixture_connection
+            )
+            # as the provide_metadata fixture is often used with "testing.db",
+            # when we do the drop we have to commit the transaction so that
+            # the DB is actually updated as the CREATE would have been
+            # committed
+            fixtures._connection_fixture_connection.get_transaction().commit()
+        else:
+            drop_all_tables_from_metadata(metadata, config.db)
         self.metadata = prev_meta
 
 
@@ -359,6 +386,29 @@ class adict(dict):
     get_all = __call__
 
 
+def drop_all_tables_from_metadata(metadata, engine_or_connection):
+    from . import engines
+
+    def go(connection):
+        engines.testing_reaper.prepare_for_drop_tables(connection)
+
+        if not connection.dialect.supports_alter:
+            from . import assertions
+
+            with assertions.expect_warnings(
+                "Can't sort tables", assert_=False
+            ):
+                metadata.drop_all(connection)
+        else:
+            metadata.drop_all(connection)
+
+    if not isinstance(engine_or_connection, Connection):
+        with engine_or_connection.begin() as connection:
+            go(connection)
+    else:
+        go(engine_or_connection)
+
+
 def drop_all_tables(engine, inspector, schema=None, include_names=None):
 
     if include_names is not None:
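The new `drop_all_tables_from_metadata` above accepts either an Engine or a Connection: given an engine it opens its own transaction-scoped connection, while a connection (such as the one held by the "connection" fixture) is used as-is so the caller's transaction is honored. A standalone sketch of that dispatch idiom, using stand-in classes rather than SQLAlchemy's own:

```python
# Engine-vs-connection dispatch: only the engine path manages its own
# transaction; a passed-in connection stays under the caller's control.
import contextlib


class FakeConnection(object):
    pass


class FakeEngine(object):
    @contextlib.contextmanager
    def begin(self):
        # stand-in for Engine.begin(): yield a transaction-scoped connection
        yield FakeConnection()


def do_drop(engine_or_connection, dropped):
    def go(connection):
        dropped.append(connection)  # stand-in for metadata.drop_all(connection)

    if not isinstance(engine_or_connection, FakeConnection):
        with engine_or_connection.begin() as connection:
            go(connection)
    else:
        go(engine_or_connection)


log = []
do_drop(FakeEngine(), log)       # engine path: opens its own transaction
do_drop(FakeConnection(), log)   # connection path: used as-is
print(len(log))
```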
index 99ecb4fb34b67cb0a39037c2190b5a6fdbe27c49..ca5a3abded150ec449333ef5d616f5962b300375 100644 (file)
@@ -230,13 +230,16 @@ class AsyncAdaptedQueue:
             return self.put_nowait(item)
 
         try:
-            if timeout:
+            if timeout is not None:
                 return self.await_(
                     asyncio.wait_for(self._queue.put(item), timeout)
                 )
             else:
                 return self.await_(self._queue.put(item))
-        except asyncio.queues.QueueFull as err:
+        except (
+            asyncio.queues.QueueFull,
+            asyncio.exceptions.TimeoutError,
+        ) as err:
             compat.raise_(
                 Full(),
                 replace_context=err,
@@ -254,14 +257,18 @@ class AsyncAdaptedQueue:
     def get(self, block=True, timeout=None):
         if not block:
             return self.get_nowait()
+
         try:
-            if timeout:
+            if timeout is not None:
                 return self.await_(
                     asyncio.wait_for(self._queue.get(), timeout)
                 )
             else:
                 return self.await_(self._queue.get())
-        except asyncio.queues.QueueEmpty as err:
+        except (
+            asyncio.queues.QueueEmpty,
+            asyncio.exceptions.TimeoutError,
+        ) as err:
             compat.raise_(
                 Empty(),
                 replace_context=err,
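The `if timeout:` → `if timeout is not None:` change above matters because `0` is falsy: a caller passing `timeout=0` (a legitimate "poll" timeout) previously fell into the unbounded-blocking branch. The added `asyncio.exceptions.TimeoutError` handlers then translate a timed-out `wait_for` into the queue's `Full`/`Empty`, matching stdlib `queue` semantics. A minimal illustration of the truthiness bug:

```python
# Why "if timeout:" mishandles timeout=0 while "is not None" does not.
def old_branch(timeout):
    # pre-patch: timeout=0 is falsy, so it selected the blocking path
    return "wait_for" if timeout else "block forever"


def new_branch(timeout):
    # post-patch: only None means "no timeout"
    return "wait_for" if timeout is not None else "block forever"


print(old_branch(0))   # wrong branch for a zero timeout
print(new_branch(0))
print(new_branch(None))
```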
index 0202768ae49316eabec806d84136fe8071cab9b4..968a747008968c8644838cd03553259459b0ec2d 100644 (file)
@@ -18,7 +18,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
 
         global t1, t2, metadata
         metadata = MetaData()
index 75a4f51cf8d439561e273120e2ec20a7957ccd19..a41a8b9f11c543d21dbccc48b35412c47eedf5e7 100644 (file)
@@ -241,7 +241,7 @@ def assert_no_mappers():
 
 
 class EnsureZeroed(fixtures.ORMTest):
-    def setup(self):
+    def setup_test(self):
         _sessions.clear()
         _mapper_registry.clear()
 
@@ -1032,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed):
 
             t2_mapper = mapper(T2, t2)
             t1_mapper.add_property("bar", relationship(t2_mapper))
-            s1 = fixture_session()
+            s1 = Session(testing.db)
             # this causes the path_registry to be invoked
             s1.query(t1_mapper)._compile_context()
 
index db6fd4b718109771b009351fa3f6bd6fe4671d28..5b30a3968b25528edf39bb86b90ce599228dcfeb 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.util import classproperty
 class EnumTest(fixtures.TestBase):
     __requires__ = ("cpython", "python_profiling_backend")
 
-    def setup(self):
+    def setup_test(self):
         class SomeEnum(object):
             # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
 
index f163078d80925170dfcc049754279465f87be523..8116e5f215d6f8ada5bbd8fcdc6266c6fd41f19b 100644 (file)
@@ -29,15 +29,13 @@ class NoCache(object):
     run_setup_bind = "each"
 
     @classmethod
-    def setup_class(cls):
-        super(NoCache, cls).setup_class()
+    def setup_test_class(cls):
         cls._cache = config.db._compiled_cache
         config.db._compiled_cache = None
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         config.db._compiled_cache = cls._cache
-        super(NoCache, cls).teardown_class()
 
 
 class MergeTest(NoCache, fixtures.MappedTest):
index fd02f9139527a5f787feb62bdd7fa100957e47fe..da3c1c52560fcf5f114d54c5d956462673b40f8f 100644 (file)
@@ -17,7 +17,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults):
         def close(self):
             pass
 
-    def setup(self):
+    def setup_test(self):
         # create a throwaway pool which
         # has the effect of initializing
         # class-level event listeners on Pool,
index 19f68e9a3509ba9a6dc0354ed752b8bfa5a514b7..68db5207ca0aad17eb218d74a3d6356c78634d16 100644 (file)
@@ -16,7 +16,7 @@ from sqlalchemy.testing.util import gc_collect
 
 
 class TearDownLocalEventsFixture(object):
-    def tearDown(self):
+    def teardown_test(self):
         classes = set()
         for entry in event.base._registrars.values():
             for evt_cls in entry:
@@ -30,7 +30,7 @@ class TearDownLocalEventsFixture(object):
 class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test class- and instance-level event registration."""
 
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             def event_one(self, x, y):
                 pass
@@ -438,7 +438,7 @@ class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
 class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """test adaption of legacy args"""
 
-    def setUp(self):
+    def setup_test(self):
         class TargetEventsOne(event.Events):
             @event._legacy_signature("0.9", ["x", "y"])
             def event_three(self, x, y, z, q):
@@ -608,7 +608,7 @@ class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
 
 class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         class TargetEventsOne(event.Events):
             def event_one(self, x, y):
                 pass
@@ -677,7 +677,7 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
 class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test default target acceptance."""
 
-    def setUp(self):
+    def setup_test(self):
         class TargetEventsOne(event.Events):
             def event_one(self, x, y):
                 pass
@@ -734,7 +734,7 @@ class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
 class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test custom target acceptance."""
 
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             @classmethod
             def _accept_with(cls, target):
@@ -771,7 +771,7 @@ class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
 class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """test that ad-hoc subclasses are garbage collected."""
 
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             def some_event(self, x, y):
                 pass
@@ -797,7 +797,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
     """Test custom listen functions which change the listener function
     signature."""
 
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             @classmethod
             def _listen(cls, event_key, add=False):
@@ -855,7 +855,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
 
 class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             def event_one(self, arg):
                 pass
@@ -889,7 +889,7 @@ class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
 
 class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             def event_one(self, target, arg):
                 pass
@@ -1109,7 +1109,7 @@ class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
 
 
 class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         class TargetEvents(event.Events):
             def event_one(self, target, arg):
                 pass
index 15b98c848b54579f7a855abd34941c8bb274cb93..252d0d9777dd9cbc0e8c0695cd705ee9f3047367 100644 (file)
@@ -13,7 +13,7 @@ class TestFixture(object):
 
 
 class TestInspection(fixtures.TestBase):
-    def tearDown(self):
+    def teardown_test(self):
         for type_ in list(inspection._registrars):
             if issubclass(type_, TestFixture):
                 del inspection._registrars[type_]
index 14e87ef6900b867fc19cfdf5e35fd6b5eb1a1da2..6320ef052781a7ca1339944e434e714b5f5ab2f1 100644 (file)
@@ -48,11 +48,11 @@ class DocTest(fixtures.TestBase):
 
         ddl.sort_tables_and_constraints = self.orig_sort
 
-    def setup(self):
+    def setup_test(self):
         self._setup_logger()
         self._setup_create_table_patcher()
 
-    def teardown(self):
+    def teardown_test(self):
         self._teardown_create_table_patcher()
         self._teardown_logger()
 
index 8119612e1a32585eb9ddee5cdfb4c503870b7555..f0bb66aa9f9289263c63fe232aa5d6ebb39f04ea 100644 (file)
@@ -1814,7 +1814,7 @@ class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL):
 
 
 class SchemaTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         t = Table(
             "sometable",
             MetaData(),
index c869182c5ab5de3a81d911bdc469a4de5bfa76f9..27709beb055a67b8efde6c87a42040d40b8ef5a3 100644 (file)
@@ -31,7 +31,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL):
 
     """
 
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
         self.t1 = table(
             "t1",
index cdb37cc61571b9e28ce18013e6af2bc0a76d7ca7..b806b9247f65bf83cb05eb3f39ddc7ad0378394b 100644 (file)
@@ -455,7 +455,7 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
         return testing.db.execution_options(isolation_level="AUTOCOMMIT")
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         with testing.db.connect().execution_options(
             isolation_level="AUTOCOMMIT"
         ) as conn:
@@ -463,7 +463,6 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
                 conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
             except:
                 pass
-        super(MatchTest, cls).setup_class()
 
     @classmethod
     def insert_data(cls, connection):
index 62292b9daad78b379247b1ba3a4c0deeac93a0cb..7fd24e8b51c2c069912c2c1c540dd741a39fe0ee 100644 (file)
@@ -991,7 +991,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL):
 class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = mysql.dialect()
 
-    def setup(self):
+    def setup_test(self):
         self.table = Table(
             "foos",
             MetaData(),
@@ -1062,7 +1062,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
 
 
 class RegexpCommon(testing.AssertsCompiledSQL):
-    def setUp(self):
+    def setup_test(self):
         self.table = table(
             "mytable", column("myid", Integer), column("name", String)
         )
index 40617e59cedfe85d2b5f4a2d04fcd81c62b0b193..795b2cbd328a174b0cb53f35d34848c797ef289a 100644 (file)
@@ -1115,7 +1115,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
 class RawReflectionTest(fixtures.TestBase):
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         dialect = mysql.dialect()
         self.parser = _reflection.MySQLTableDefinitionParser(
             dialect, dialect.identifier_preparer
index 1b8b3fb89b76f929b5cdac4b2cc3d58ea25a81e4..f09346eb32d101d68401f8e452cf6257d4f1d2b7 100644 (file)
@@ -1355,7 +1355,7 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
 class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "oracle"
 
-    def setUp(self):
+    def setup_test(self):
         self.table = table(
             "mytable", column("myid", Integer), column("name", String)
         )
index df87fe89fc864b3babb630bf29f22a9f68a433c9..32234bf653b1068bc1eb754695865a88a7cfa091 100644 (file)
@@ -439,7 +439,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         with testing.db.begin() as c:
             c.exec_driver_sql(
                 """
@@ -471,7 +471,7 @@ end;
         assert isinstance(result.out_parameters["x_out"], int)
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         with testing.db.begin() as conn:
             conn.execute(text("DROP PROCEDURE foo"))
 
index 81e4e4ab5aa3b0de6ddd861fefd4ffe2b0c7f1e3..0df4236e25c4410cd928e0ed7cdf2125d942d962 100644 (file)
@@ -39,7 +39,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         # currently assuming full DBA privs for the user.
         # don't really know how else to go here unless
         # we connect as the other user.
@@ -85,7 +85,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
                     conn.exec_driver_sql(stmt)
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         with testing.db.begin() as conn:
             for stmt in (
                 """
@@ -379,7 +379,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
     __only_on__ = "oracle"
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         with testing.db.begin() as conn:
             conn.exec_driver_sql("create table my_table (id integer)")
             conn.exec_driver_sql(
@@ -389,7 +389,7 @@ class SystemTableTablenamesTest(fixtures.TestBase):
                 "create table foo_table (id integer) tablespace SYSTEM"
             )
 
-    def teardown(self):
+    def teardown_test(self):
         with testing.db.begin() as conn:
             conn.exec_driver_sql("drop table my_temp_table")
             conn.exec_driver_sql("drop table my_table")
@@ -421,7 +421,7 @@ class DontReflectIOTTest(fixtures.TestBase):
     __only_on__ = "oracle"
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         with testing.db.begin() as conn:
             conn.exec_driver_sql(
                 """
@@ -438,7 +438,7 @@ class DontReflectIOTTest(fixtures.TestBase):
             """,
             )
 
-    def teardown(self):
+    def teardown_test(self):
         with testing.db.begin() as conn:
             conn.exec_driver_sql("drop table admin_docindex")
 
@@ -715,7 +715,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy.testing import config
 
         cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
@@ -734,7 +734,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
             )
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         with testing.db.begin() as conn:
             conn.exec_driver_sql("drop synonym test_table_syn")
             conn.exec_driver_sql("drop table test_table")
index f008ea01928b74a3fc0f5bfebc32a0061ecaaba2..8ea7c0e044ec06986f6abf803ddb6b331943163c 100644 (file)
@@ -1011,7 +1011,7 @@ class EuroNumericTest(fixtures.TestBase):
     __only_on__ = "oracle+cx_oracle"
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         connect = testing.db.pool._creator
 
         def _creator():
@@ -1023,7 +1023,7 @@ class EuroNumericTest(fixtures.TestBase):
 
         self.engine = testing_engine(options={"creator": _creator})
 
-    def teardown(self):
+    def teardown_test(self):
         self.engine.dispose()
 
     def test_were_getting_a_comma(self):
index fadf939b868d6a2ac3c831724191344d91497852..f6d48f3c65b0a4360d594544f3af3d65f0e1db40 100644 (file)
@@ -27,7 +27,7 @@ class AsyncPgTest(fixtures.TestBase):
         # TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3
         # merges
 
-        from sqlalchemy.testing import engines
+        from sqlalchemy.testing import util as testing_util
         from sqlalchemy.sql import schema
 
         metadata = schema.MetaData()
@@ -35,7 +35,7 @@ class AsyncPgTest(fixtures.TestBase):
         try:
             yield metadata
         finally:
-            engines.drop_all_tables(metadata, testing.db)
+            testing_util.drop_all_tables_from_metadata(metadata, testing.db)
 
     @async_test
     async def test_detect_stale_ddl_cache_raise_recover(
index 1763b210b2bd1df28aedce46791b10d8ee633786..b3a0b9bbded4adbd8c692e0c079ffbb8fa8517b4 100644 (file)
@@ -1810,7 +1810,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = postgresql.dialect()
 
-    def setup(self):
+    def setup_test(self):
         self.table1 = table1 = table(
             "mytable",
             column("myid", Integer),
@@ -2222,7 +2222,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __dialect__ = postgresql.dialect()
 
-    def setup(self):
+    def setup_test(self):
         self.table = Table(
             "t",
             MetaData(),
@@ -2373,7 +2373,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __dialect__ = postgresql.dialect()
 
-    def setup(self):
+    def setup_test(self):
         self.table = Table(
             "t",
             MetaData(),
@@ -2464,7 +2464,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
 class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "postgresql"
 
-    def setUp(self):
+    def setup_test(self):
         self.table = table(
             "mytable", column("myid", Integer), column("name", String)
         )
index f760a309b429b8485e4228c647862086a0775d4b..9c9d817bba958d52fdfe7faba30e8def152e522a 100644 (file)
@@ -198,17 +198,17 @@ class ExecuteManyMode(object):
 
     @config.fixture()
     def connection(self):
-        eng = engines.testing_engine(options=self.options)
+        opts = dict(self.options)
+        opts["use_reaper"] = False
+        eng = engines.testing_engine(options=opts)
 
         conn = eng.connect()
         trans = conn.begin()
-        try:
-            yield conn
-        finally:
-            if trans.is_active:
-                trans.rollback()
-            conn.close()
-            eng.dispose()
+        yield conn
+        if trans.is_active:
+            trans.rollback()
+        conn.close()
+        eng.dispose()
 
     @classmethod
     def define_tables(cls, metadata):
@@ -510,8 +510,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
         # assert result.closed
         assert result.cursor is None
 
-    @testing.provide_metadata
-    def test_insert_returning_preexecute_pk(self, connection):
+    def test_insert_returning_preexecute_pk(self, metadata, connection):
         counter = itertools.count(1)
 
         t = Table(
@@ -525,7 +524,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
             ),
             Column("data", Integer),
         )
-        self.metadata.create_all(connection)
+        metadata.create_all(connection)
 
         result = connection.execute(
             t.insert().return_defaults(),
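The `connection` fixture above drops its try/finally around the `yield`. With pytest generator fixtures this is safe for ordinary cleanup, because pytest resumes the generator after the test regardless of the test's outcome — code after a bare `yield` runs for passing and failing tests alike. A plain-Python sketch (not pytest itself) of that driver behavior:

```python
# How a driver like pytest finalizes a generator fixture: setup runs up
# to the yield, then one more next() after the test, whatever happened.
events = []


def connection_like_fixture():
    events.append("connect/begin")
    yield "conn"
    events.append("rollback/close/dispose")  # no try/finally needed


def run_test(body):
    gen = connection_like_fixture()
    conn = next(gen)           # setup
    try:
        body(conn)             # the test itself
    except Exception:
        events.append("test failed")
    # the driver always resumes the fixture, pass or fail
    try:
        next(gen)
    except StopIteration:
        pass


def ok(conn):
    events.append("test ok")


def broken(conn):
    raise RuntimeError("boom")


run_test(ok)
run_test(broken)
print(events)
```

The try/finally would only matter if the driver threw the test's exception *into* the generator, which pytest does not do for fixture finalization.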
index 94af168eee0e925dfc9b85a1ee80f02b97615c89..c51fd19432c2e5b36c7fdade006e3ba63603d536 100644 (file)
@@ -40,10 +40,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
     __only_on__ = "postgresql"
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         self.metadata = MetaData()
 
-    def teardown(self):
+    def teardown_test(self):
         with testing.db.begin() as conn:
             self.metadata.drop_all(conn)
 
@@ -890,7 +890,7 @@ class ExtractTest(fixtures.TablesTest):
     def setup_bind(cls):
         from sqlalchemy import event
 
-        eng = engines.testing_engine()
+        eng = engines.testing_engine(options={"scope": "class"})
 
         @event.listens_for(eng, "connect")
         def connect(dbapi_conn, rec):
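The `setup`/`teardown` renames above exist because pytest hardcodes the nose/xdist-style names and runs them at a point where function-scoped fixtures are not yet (or no longer) active; the commit's own `setup_test`/`teardown_test` hooks are instead driven by an autouse fixture at a controlled point in the flow. A plain-Python sketch of that wrapping contract (the runner function here is a hypothetical stand-in, not SQLAlchemy's plugin code):

```python
events = []


class MyTest:
    # renamed hooks: pytest no longer picks these up implicitly
    def setup_test(self):
        events.append("setup_test")

    def teardown_test(self):
        events.append("teardown_test")

    def test_something(self):
        events.append("test")


def run_with_hooks(instance, test_name):
    """Sketch of what the autouse fixture does: wrap each test in the
    renamed setup_test()/teardown_test() hooks, guaranteeing teardown
    runs even when the test itself raises."""
    getattr(instance, "setup_test", lambda: None)()
    try:
        getattr(instance, test_name)()
    finally:
        getattr(instance, "teardown_test", lambda: None)()


run_with_hooks(MyTest(), "test_something")
assert events == ["setup_test", "test", "teardown_test"]
```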
index 754eff25a0bfc82776035a2cdf1cf33afbe8f918..6586a8308d9cbc859ee6ab7ab10def39cc4dadab 100644 (file)
@@ -80,26 +80,24 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
         ]:
             sa.event.listen(metadata, "before_drop", sa.DDL(ddl))
 
-    def test_foreign_table_is_reflected(self):
+    def test_foreign_table_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("test_foreigntable", metadata, autoload_with=testing.db)
+        table = Table("test_foreigntable", metadata, autoload_with=connection)
         eq_(
             set(table.columns.keys()),
             set(["id", "data"]),
             "Columns of reflected foreign table didn't equal expected columns",
         )
 
-    def test_get_foreign_table_names(self):
-        inspector = inspect(testing.db)
-        with testing.db.connect():
-            ft_names = inspector.get_foreign_table_names()
-            eq_(ft_names, ["test_foreigntable"])
+    def test_get_foreign_table_names(self, connection):
+        inspector = inspect(connection)
+        ft_names = inspector.get_foreign_table_names()
+        eq_(ft_names, ["test_foreigntable"])
 
-    def test_get_table_names_no_foreign(self):
-        inspector = inspect(testing.db)
-        with testing.db.connect():
-            names = inspector.get_table_names()
-            eq_(names, ["testtable"])
+    def test_get_table_names_no_foreign(self, connection):
+        inspector = inspect(connection)
+        names = inspector.get_table_names()
+        eq_(names, ["testtable"])
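The recurring change in the hunks that follow is `inspect(testing.db)` becoming `inspect(connection)`: inspecting over the test's own connection sees DDL issued inside its still-open transaction and avoids checking out a second connection from the pool. A small sketch of the idea, assuming SQLAlchemy is installed and using in-memory SQLite as a stand-in for Postgres:

```python
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
with engine.connect() as conn:
    trans = conn.begin()
    conn.exec_driver_sql("CREATE TABLE t (id INTEGER PRIMARY KEY)")
    # inspect(connection) runs over this same connection, so the table
    # created in the uncommitted transaction is visible; inspect(engine)
    # would have to acquire a separate pooled connection
    assert "t" in sa.inspect(conn).get_table_names()
    trans.rollback()  # rollback undoes the DDL, no drop_all needed
```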
 
 
 class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
@@ -133,22 +131,22 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
         if testing.against("postgresql >= 11"):
             Index("my_index", dv.c.q)
 
-    def test_get_tablenames(self):
+    def test_get_tablenames(self, connection):
         assert {"data_values", "data_values_4_10"}.issubset(
-            inspect(testing.db).get_table_names()
+            inspect(connection).get_table_names()
         )
 
-    def test_reflect_cols(self):
-        cols = inspect(testing.db).get_columns("data_values")
+    def test_reflect_cols(self, connection):
+        cols = inspect(connection).get_columns("data_values")
         eq_([c["name"] for c in cols], ["modulus", "data", "q"])
 
-    def test_reflect_cols_from_partition(self):
-        cols = inspect(testing.db).get_columns("data_values_4_10")
+    def test_reflect_cols_from_partition(self, connection):
+        cols = inspect(connection).get_columns("data_values_4_10")
         eq_([c["name"] for c in cols], ["modulus", "data", "q"])
 
     @testing.only_on("postgresql >= 11")
-    def test_reflect_index(self):
-        idx = inspect(testing.db).get_indexes("data_values")
+    def test_reflect_index(self, connection):
+        idx = inspect(connection).get_indexes("data_values")
         eq_(
             idx,
             [
@@ -162,8 +160,8 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
         )
 
     @testing.only_on("postgresql >= 11")
-    def test_reflect_index_from_partition(self):
-        idx = inspect(testing.db).get_indexes("data_values_4_10")
+    def test_reflect_index_from_partition(self, connection):
+        idx = inspect(connection).get_indexes("data_values_4_10")
         # note the name appears to be generated by PG, currently
         # 'data_values_4_10_q_idx'
         eq_(
@@ -220,44 +218,43 @@ class MaterializedViewReflectionTest(
             testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
         )
 
-    def test_mview_is_reflected(self):
+    def test_mview_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("test_mview", metadata, autoload_with=testing.db)
+        table = Table("test_mview", metadata, autoload_with=connection)
         eq_(
             set(table.columns.keys()),
             set(["id", "data"]),
             "Columns of reflected mview didn't equal expected columns",
         )
 
-    def test_mview_select(self):
+    def test_mview_select(self, connection):
         metadata = MetaData()
-        table = Table("test_mview", metadata, autoload_with=testing.db)
-        with testing.db.connect() as conn:
-            eq_(conn.execute(table.select()).fetchall(), [(89, "d1")])
+        table = Table("test_mview", metadata, autoload_with=connection)
+        eq_(connection.execute(table.select()).fetchall(), [(89, "d1")])
 
-    def test_get_view_names(self):
-        insp = inspect(testing.db)
+    def test_get_view_names(self, connection):
+        insp = inspect(connection)
         eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
 
-    def test_get_view_names_plain(self):
-        insp = inspect(testing.db)
+    def test_get_view_names_plain(self, connection):
+        insp = inspect(connection)
         eq_(
             set(insp.get_view_names(include=("plain",))), set(["test_regview"])
         )
 
-    def test_get_view_names_plain_string(self):
-        insp = inspect(testing.db)
+    def test_get_view_names_plain_string(self, connection):
+        insp = inspect(connection)
         eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
 
-    def test_get_view_names_materialized(self):
-        insp = inspect(testing.db)
+    def test_get_view_names_materialized(self, connection):
+        insp = inspect(connection)
         eq_(
             set(insp.get_view_names(include=("materialized",))),
             set(["test_mview"]),
         )
 
-    def test_get_view_names_reflection_cache_ok(self):
-        insp = inspect(testing.db)
+    def test_get_view_names_reflection_cache_ok(self, connection):
+        insp = inspect(connection)
         eq_(
             set(insp.get_view_names(include=("plain",))), set(["test_regview"])
         )
@@ -267,12 +264,12 @@ class MaterializedViewReflectionTest(
         )
         eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
 
-    def test_get_view_names_empty(self):
-        insp = inspect(testing.db)
+    def test_get_view_names_empty(self, connection):
+        insp = inspect(connection)
         assert_raises(ValueError, insp.get_view_names, include=())
 
-    def test_get_view_definition(self):
-        insp = inspect(testing.db)
+    def test_get_view_definition(self, connection):
+        insp = inspect(connection)
         eq_(
             re.sub(
                 r"[\n\t ]+",
@@ -290,7 +287,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         with testing.db.begin() as con:
             for ddl in [
                 'CREATE SCHEMA "SomeSchema"',
@@ -334,7 +331,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             )
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         with testing.db.begin() as con:
             con.exec_driver_sql("DROP TABLE testtable")
             con.exec_driver_sql("DROP TABLE test_schema.testtable")
@@ -350,9 +347,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
             con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
 
-    def test_table_is_reflected(self):
+    def test_table_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("testtable", metadata, autoload_with=testing.db)
+        table = Table("testtable", metadata, autoload_with=connection)
         eq_(
             set(table.columns.keys()),
             set(["question", "answer"]),
@@ -360,9 +357,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
         )
         assert isinstance(table.c.answer.type, Integer)
 
-    def test_domain_is_reflected(self):
+    def test_domain_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("testtable", metadata, autoload_with=testing.db)
+        table = Table("testtable", metadata, autoload_with=connection)
         eq_(
             str(table.columns.answer.server_default.arg),
             "42",
@@ -372,28 +369,28 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             not table.columns.answer.nullable
         ), "Expected reflected column to not be nullable."
 
-    def test_enum_domain_is_reflected(self):
+    def test_enum_domain_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("enum_test", metadata, autoload_with=testing.db)
+        table = Table("enum_test", metadata, autoload_with=connection)
         eq_(table.c.data.type.enums, ["test"])
 
-    def test_array_domain_is_reflected(self):
+    def test_array_domain_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("array_test", metadata, autoload_with=testing.db)
+        table = Table("array_test", metadata, autoload_with=connection)
         eq_(table.c.data.type.__class__, ARRAY)
         eq_(table.c.data.type.item_type.__class__, INTEGER)
 
-    def test_quoted_remote_schema_domain_is_reflected(self):
+    def test_quoted_remote_schema_domain_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("quote_test", metadata, autoload_with=testing.db)
+        table = Table("quote_test", metadata, autoload_with=connection)
         eq_(table.c.data.type.__class__, INTEGER)
 
-    def test_table_is_reflected_test_schema(self):
+    def test_table_is_reflected_test_schema(self, connection):
         metadata = MetaData()
         table = Table(
             "testtable",
             metadata,
-            autoload_with=testing.db,
+            autoload_with=connection,
             schema="test_schema",
         )
         eq_(
@@ -403,12 +400,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
         )
         assert isinstance(table.c.anything.type, Integer)
 
-    def test_schema_domain_is_reflected(self):
+    def test_schema_domain_is_reflected(self, connection):
         metadata = MetaData()
         table = Table(
             "testtable",
             metadata,
-            autoload_with=testing.db,
+            autoload_with=connection,
             schema="test_schema",
         )
         eq_(
@@ -420,9 +417,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             table.columns.answer.nullable
         ), "Expected reflected column to be nullable."
 
-    def test_crosschema_domain_is_reflected(self):
+    def test_crosschema_domain_is_reflected(self, connection):
         metadata = MetaData()
-        table = Table("crosschema", metadata, autoload_with=testing.db)
+        table = Table("crosschema", metadata, autoload_with=connection)
         eq_(
             str(table.columns.answer.server_default.arg),
             "0",
@@ -432,7 +429,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             table.columns.answer.nullable
         ), "Expected reflected column to be nullable."
 
-    def test_unknown_types(self):
+    def test_unknown_types(self, connection):
         from sqlalchemy.dialects.postgresql import base
 
         ischema_names = base.PGDialect.ischema_names
@@ -440,13 +437,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
         try:
             m2 = MetaData()
             assert_raises(
-                exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db
+                exc.SAWarning, Table, "testtable", m2, autoload_with=connection
             )
 
             @testing.emits_warning("Did not recognize type")
             def warns():
                 m3 = MetaData()
-                t3 = Table("testtable", m3, autoload_with=testing.db)
+                t3 = Table("testtable", m3, autoload_with=connection)
                 assert t3.c.answer.type.__class__ == sa.types.NullType
 
         finally:
@@ -471,9 +468,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
         subject = Table("subject", meta2, autoload_with=connection)
         eq_(subject.primary_key.columns.keys(), ["p2", "p1"])
 
-    @testing.provide_metadata
-    def test_pg_weirdchar_reflection(self):
-        meta1 = self.metadata
+    def test_pg_weirdchar_reflection(self, metadata, connection):
+        meta1 = metadata
         subject = Table(
             "subject", meta1, Column("id$", Integer, primary_key=True)
         )
@@ -483,101 +479,91 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("ref", Integer, ForeignKey("subject.id$")),
         )
-        meta1.create_all(testing.db)
+        meta1.create_all(connection)
         meta2 = MetaData()
-        subject = Table("subject", meta2, autoload_with=testing.db)
-        referer = Table("referer", meta2, autoload_with=testing.db)
+        subject = Table("subject", meta2, autoload_with=connection)
+        referer = Table("referer", meta2, autoload_with=connection)
         self.assert_(
             (subject.c["id$"] == referer.c.ref).compare(
                 subject.join(referer).onclause
             )
         )
 
-    @testing.provide_metadata
-    def test_reflect_default_over_128_chars(self):
+    def test_reflect_default_over_128_chars(self, metadata, connection):
         Table(
             "t",
-            self.metadata,
+            metadata,
             Column("x", String(200), server_default="abcd" * 40),
-        ).create(testing.db)
+        ).create(connection)
 
         m = MetaData()
-        t = Table("t", m, autoload_with=testing.db)
+        t = Table("t", m, autoload_with=connection)
         eq_(
             t.c.x.server_default.arg.text,
             "'%s'::character varying" % ("abcd" * 40),
         )
 
-    @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure")
-    @testing.provide_metadata
-    def test_renamed_sequence_reflection(self):
-        metadata = self.metadata
+    def test_renamed_sequence_reflection(self, metadata, connection):
         Table("t", metadata, Column("id", Integer, primary_key=True))
-        metadata.create_all(testing.db)
+        metadata.create_all(connection)
         m2 = MetaData()
-        t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False)
+        t2 = Table("t", m2, autoload_with=connection, implicit_returning=False)
         eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)")
-        with testing.db.begin() as conn:
-            r = conn.execute(t2.insert())
-            eq_(r.inserted_primary_key, (1,))
+        r = connection.execute(t2.insert())
+        eq_(r.inserted_primary_key, (1,))
 
-        with testing.db.begin() as conn:
-            conn.exec_driver_sql(
-                "alter table t_id_seq rename to foobar_id_seq"
-            )
+        connection.exec_driver_sql(
+            "alter table t_id_seq rename to foobar_id_seq"
+        )
         m3 = MetaData()
-        t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
+        t3 = Table("t", m3, autoload_with=connection, implicit_returning=False)
         eq_(
             t3.c.id.server_default.arg.text,
             "nextval('foobar_id_seq'::regclass)",
         )
-        with testing.db.begin() as conn:
-            r = conn.execute(t3.insert())
-            eq_(r.inserted_primary_key, (2,))
+        r = connection.execute(t3.insert())
+        eq_(r.inserted_primary_key, (2,))
 
-    @testing.provide_metadata
-    def test_altered_type_autoincrement_pk_reflection(self):
-        metadata = self.metadata
+    def test_altered_type_autoincrement_pk_reflection(
+        self, metadata, connection
+    ):
         Table(
             "t",
             metadata,
             Column("id", Integer, primary_key=True),
             Column("x", Integer),
         )
-        metadata.create_all(testing.db)
+        metadata.create_all(connection)
 
-        with testing.db.begin() as conn:
-            conn.exec_driver_sql(
-                "alter table t alter column id type varchar(50)"
-            )
+        connection.exec_driver_sql(
+            "alter table t alter column id type varchar(50)"
+        )
         m2 = MetaData()
-        t2 = Table("t", m2, autoload_with=testing.db)
+        t2 = Table("t", m2, autoload_with=connection)
         eq_(t2.c.id.autoincrement, False)
         eq_(t2.c.x.autoincrement, False)
 
-    @testing.provide_metadata
-    def test_renamed_pk_reflection(self):
-        metadata = self.metadata
+    def test_renamed_pk_reflection(self, metadata, connection):
         Table("t", metadata, Column("id", Integer, primary_key=True))
-        metadata.create_all(testing.db)
-        with testing.db.begin() as conn:
-            conn.exec_driver_sql("alter table t rename id to t_id")
+        metadata.create_all(connection)
+        connection.exec_driver_sql("alter table t rename id to t_id")
         m2 = MetaData()
-        t2 = Table("t", m2, autoload_with=testing.db)
+        t2 = Table("t", m2, autoload_with=connection)
         eq_([c.name for c in t2.primary_key], ["t_id"])
 
-    @testing.provide_metadata
-    def test_has_temporary_table(self):
-        assert not inspect(testing.db).has_table("some_temp_table")
+    def test_has_temporary_table(self, metadata, connection):
+        assert not inspect(connection).has_table("some_temp_table")
         user_tmp = Table(
             "some_temp_table",
-            self.metadata,
+            metadata,
             Column("id", Integer, primary_key=True),
             Column("name", String(50)),
             prefixes=["TEMPORARY"],
         )
-        user_tmp.create(testing.db)
-        assert inspect(testing.db).has_table("some_temp_table")
+        user_tmp.create(connection)
+        assert inspect(connection).has_table("some_temp_table")
 
     def test_cross_schema_reflection_one(self, metadata, connection):
 
@@ -898,19 +884,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
         A_table.create(connection, checkfirst=True)
         assert inspect(connection).has_table("A")
 
-    def test_uppercase_lowercase_sequence(self):
+    def test_uppercase_lowercase_sequence(self, connection):
 
         a_seq = Sequence("a")
         A_seq = Sequence("A")
 
-        a_seq.create(testing.db)
-        assert testing.db.dialect.has_sequence(testing.db, "a")
-        assert not testing.db.dialect.has_sequence(testing.db, "A")
-        A_seq.create(testing.db, checkfirst=True)
-        assert testing.db.dialect.has_sequence(testing.db, "A")
+        a_seq.create(connection)
+        assert connection.dialect.has_sequence(connection, "a")
+        assert not connection.dialect.has_sequence(connection, "A")
+        A_seq.create(connection, checkfirst=True)
+        assert connection.dialect.has_sequence(connection, "A")
 
-        a_seq.drop(testing.db)
-        A_seq.drop(testing.db)
+        a_seq.drop(connection)
+        A_seq.drop(connection)
 
     def test_index_reflection(self, metadata, connection):
         """Reflecting expression-based indexes should warn"""
@@ -960,11 +946,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
-    @testing.provide_metadata
-    def test_index_reflection_partial(self, connection):
+    def test_index_reflection_partial(self, metadata, connection):
         """Reflect the filter defintion on partial indexes"""
 
-        metadata = self.metadata
 
         t1 = Table(
             "table1",
@@ -978,7 +963,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
 
         metadata.create_all(connection)
 
-        ind = testing.db.dialect.get_indexes(connection, t1, None)
+        ind = connection.dialect.get_indexes(connection, t1, None)
 
         partial_definitions = []
         for ix in ind:
@@ -1073,15 +1058,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             compile_exprs(r3.expressions),
         )
 
-    @testing.provide_metadata
-    def test_index_reflection_modified(self):
+    def test_index_reflection_modified(self, metadata, connection):
         """reflect indexes when a column name has changed - PG 9
         does not update the name of the column in the index def.
         [ticket:2141]
 
         """
 
-        metadata = self.metadata
 
         Table(
             "t",
@@ -1089,26 +1073,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("x", Integer),
         )
-        metadata.create_all(testing.db)
-        with testing.db.begin() as conn:
-            conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
-            conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+        metadata.create_all(connection)
+        connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+        connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
 
-            ind = testing.db.dialect.get_indexes(conn, "t", None)
-            expected = [
-                {"name": "idx1", "unique": False, "column_names": ["y"]}
-            ]
-            if testing.requires.index_reflects_included_columns.enabled:
-                expected[0]["include_columns"] = []
+        ind = connection.dialect.get_indexes(connection, "t", None)
+        expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
+        if testing.requires.index_reflects_included_columns.enabled:
+            expected[0]["include_columns"] = []
 
-            eq_(ind, expected)
+        eq_(ind, expected)
 
-    @testing.fails_if("postgresql < 8.2", "reloptions not supported")
-    @testing.provide_metadata
-    def test_index_reflection_with_storage_options(self):
+    def test_index_reflection_with_storage_options(self, metadata, connection):
         """reflect indexes with storage options set"""
 
-        metadata = self.metadata
 
         Table(
             "t",
@@ -1116,70 +1095,63 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("x", Integer),
         )
-        metadata.create_all(testing.db)
+        metadata.create_all(connection)
 
-        with testing.db.begin() as conn:
-            conn.exec_driver_sql(
-                "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
-            )
+        connection.exec_driver_sql(
+            "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
+        )
 
-            ind = testing.db.dialect.get_indexes(conn, "t", None)
+        ind = connection.dialect.get_indexes(connection, "t", None)
 
-            expected = [
-                {
-                    "unique": False,
-                    "column_names": ["x"],
-                    "name": "idx1",
-                    "dialect_options": {
-                        "postgresql_with": {"fillfactor": "50"}
-                    },
-                }
-            ]
-            if testing.requires.index_reflects_included_columns.enabled:
-                expected[0]["include_columns"] = []
-            eq_(ind, expected)
+        expected = [
+            {
+                "unique": False,
+                "column_names": ["x"],
+                "name": "idx1",
+                "dialect_options": {"postgresql_with": {"fillfactor": "50"}},
+            }
+        ]
+        if testing.requires.index_reflects_included_columns.enabled:
+            expected[0]["include_columns"] = []
+        eq_(ind, expected)
 
-            m = MetaData()
-            t1 = Table("t", m, autoload_with=conn)
-            eq_(
-                list(t1.indexes)[0].dialect_options["postgresql"]["with"],
-                {"fillfactor": "50"},
-            )
+        m = MetaData()
+        t1 = Table("t", m, autoload_with=connection)
+        eq_(
+            list(t1.indexes)[0].dialect_options["postgresql"]["with"],
+            {"fillfactor": "50"},
+        )
 
-    @testing.provide_metadata
-    def test_index_reflection_with_access_method(self):
+    def test_index_reflection_with_access_method(self, metadata, connection):
         """reflect indexes with storage options set"""
 
-        metadata = self.metadata
-
         Table(
             "t",
             metadata,
             Column("id", Integer, primary_key=True),
             Column("x", ARRAY(Integer)),
         )
-        metadata.create_all(testing.db)
-        with testing.db.begin() as conn:
-            conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
+        metadata.create_all(connection)
+        connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
 
-            ind = testing.db.dialect.get_indexes(conn, "t", None)
-            expected = [
-                {
-                    "unique": False,
-                    "column_names": ["x"],
-                    "name": "idx1",
-                    "dialect_options": {"postgresql_using": "gin"},
-                }
-            ]
-            if testing.requires.index_reflects_included_columns.enabled:
-                expected[0]["include_columns"] = []
-            eq_(ind, expected)
-            m = MetaData()
-            t1 = Table("t", m, autoload_with=conn)
-            eq_(
-                list(t1.indexes)[0].dialect_options["postgresql"]["using"],
-                "gin",
-            )
+        ind = connection.dialect.get_indexes(connection, "t", None)
+        expected = [
+            {
+                "unique": False,
+                "column_names": ["x"],
+                "name": "idx1",
+                "dialect_options": {"postgresql_using": "gin"},
+            }
+        ]
+        if testing.requires.index_reflects_included_columns.enabled:
+            expected[0]["include_columns"] = []
+        eq_(ind, expected)
+        m = MetaData()
+        t1 = Table("t", m, autoload_with=connection)
+        eq_(
+            list(t1.indexes)[0].dialect_options["postgresql"]["using"],
+            "gin",
+        )
 
     @testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
     def test_index_reflection_with_include(self, metadata, connection):
@@ -1199,7 +1171,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
         # [{'column_names': ['x', 'name'],
         #  'name': 'idx1', 'unique': False}]
 
-        ind = testing.db.dialect.get_indexes(connection, "t", None)
+        ind = connection.dialect.get_indexes(connection, "t", None)
         eq_(
             ind,
             [
@@ -1286,15 +1258,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
         for fk in fks:
             eq_(fk, fk_ref[fk["name"]])
 
-    @testing.provide_metadata
-    def test_inspect_enums_schema(self, connection):
+    def test_inspect_enums_schema(self, metadata, connection):
         enum_type = postgresql.ENUM(
             "sad",
             "ok",
             "happy",
             name="mood",
             schema="test_schema",
-            metadata=self.metadata,
+            metadata=metadata,
         )
         enum_type.create(connection)
         inspector = inspect(connection)
@@ -1310,13 +1281,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
-    @testing.provide_metadata
-    def test_inspect_enums(self):
+    def test_inspect_enums(self, metadata, connection):
         enum_type = postgresql.ENUM(
-            "cat", "dog", "rat", name="pet", metadata=self.metadata
+            "cat", "dog", "rat", name="pet", metadata=metadata
         )
-        enum_type.create(testing.db)
-        inspector = inspect(testing.db)
+        enum_type.create(connection)
+        inspector = inspect(connection)
         eq_(
             inspector.get_enums(),
             [
@@ -1329,17 +1299,16 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
-    @testing.provide_metadata
-    def test_inspect_enums_case_sensitive(self):
+    def test_inspect_enums_case_sensitive(self, metadata, connection):
         sa.event.listen(
-            self.metadata,
+            metadata,
             "before_create",
             sa.DDL('create schema "TestSchema"'),
         )
         sa.event.listen(
-            self.metadata,
+            metadata,
             "after_drop",
-            sa.DDL('drop schema "TestSchema" cascade'),
+            sa.DDL('drop schema if exists "TestSchema" cascade'),
         )
 
         for enum in "lower_case", "UpperCase", "Name.With.Dot":
@@ -1350,11 +1319,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
                     "CapsTwo",
                     name=enum,
                     schema=schema,
-                    metadata=self.metadata,
+                    metadata=metadata,
                 )
 
-        self.metadata.create_all(testing.db)
-        inspector = inspect(testing.db)
+        metadata.create_all(connection)
+        inspector = inspect(connection)
         for schema in None, "test_schema", "TestSchema":
             eq_(
                 sorted(
@@ -1382,17 +1351,18 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
                 ],
             )
 
-    @testing.provide_metadata
-    def test_inspect_enums_case_sensitive_from_table(self):
+    def test_inspect_enums_case_sensitive_from_table(
+        self, metadata, connection
+    ):
         sa.event.listen(
-            self.metadata,
+            metadata,
             "before_create",
             sa.DDL('create schema "TestSchema"'),
         )
         sa.event.listen(
-            self.metadata,
+            metadata,
             "after_drop",
-            sa.DDL('drop schema "TestSchema" cascade'),
+            sa.DDL('drop schema if exists "TestSchema" cascade'),
         )
 
         counter = itertools.count()
@@ -1403,19 +1373,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
                     "CapsOne",
                     "CapsTwo",
                     name=enum,
-                    metadata=self.metadata,
+                    metadata=metadata,
                     schema=schema,
                 )
 
                 Table(
                     "t%d" % next(counter),
-                    self.metadata,
+                    metadata,
                     Column("q", enum_type),
                 )
 
-        self.metadata.create_all(testing.db)
+        metadata.create_all(connection)
 
-        inspector = inspect(testing.db)
+        inspector = inspect(connection)
         counter = itertools.count()
         for enum in "lower_case", "UpperCase", "Name.With.Dot":
             for schema in None, "test_schema", "TestSchema":
@@ -1439,10 +1409,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
                     ],
                 )
 
-    @testing.provide_metadata
-    def test_inspect_enums_star(self):
+    def test_inspect_enums_star(self, metadata, connection):
         enum_type = postgresql.ENUM(
-            "cat", "dog", "rat", name="pet", metadata=self.metadata
+            "cat", "dog", "rat", name="pet", metadata=metadata
         )
         schema_enum_type = postgresql.ENUM(
             "sad",
@@ -1450,11 +1419,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             "happy",
             name="mood",
             schema="test_schema",
-            metadata=self.metadata,
+            metadata=metadata,
         )
-        enum_type.create(testing.db)
-        schema_enum_type.create(testing.db)
-        inspector = inspect(testing.db)
+        enum_type.create(connection)
+        schema_enum_type.create(connection)
+        inspector = inspect(connection)
 
         eq_(
             inspector.get_enums(),
@@ -1486,11 +1455,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
-    @testing.provide_metadata
-    def test_inspect_enum_empty(self):
-        enum_type = postgresql.ENUM(name="empty", metadata=self.metadata)
-        enum_type.create(testing.db)
-        inspector = inspect(testing.db)
+    def test_inspect_enum_empty(self, metadata, connection):
+        enum_type = postgresql.ENUM(name="empty", metadata=metadata)
+        enum_type.create(connection)
+        inspector = inspect(connection)
 
         eq_(
             inspector.get_enums(),
@@ -1504,13 +1472,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
-    @testing.provide_metadata
-    def test_inspect_enum_empty_from_table(self):
+    def test_inspect_enum_empty_from_table(self, metadata, connection):
         Table(
-            "t", self.metadata, Column("x", postgresql.ENUM(name="empty"))
-        ).create(testing.db)
+            "t", metadata, Column("x", postgresql.ENUM(name="empty"))
+        ).create(connection)
 
-        t = Table("t", MetaData(), autoload_with=testing.db)
+        t = Table("t", MetaData(), autoload_with=connection)
         eq_(t.c.x.type.enums, [])
 
     def test_reflection_with_unique_constraint(self, metadata, connection):
@@ -1749,12 +1716,12 @@ class CustomTypeReflectionTest(fixtures.TestBase):
 
     ischema_names = None
 
-    def setup(self):
+    def setup_test(self):
         ischema_names = postgresql.PGDialect.ischema_names
         postgresql.PGDialect.ischema_names = ischema_names.copy()
         self.ischema_names = ischema_names
 
-    def teardown(self):
+    def teardown_test(self):
         postgresql.PGDialect.ischema_names = self.ischema_names
         self.ischema_names = None
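The rename above (`setup`/`teardown` to `setup_test`/`teardown_test`) is the heart of the commit: pytest hardcodes the xdist-style names and runs them at a point where function-scoped fixtures are unavailable, so the plugin now drives explicitly named hooks itself. A minimal sketch of that ordering, using a hypothetical `FakeTestCase` stand-in for `fixtures.TestBase` (not SQLAlchemy's actual plugin code):

```python
from contextlib import contextmanager


class FakeTestCase:
    """Hypothetical stand-in for fixtures.TestBase: the plugin looks for
    explicitly named setup_test()/teardown_test() hooks instead of the
    pytest/xdist-reserved setup()/teardown() names."""

    def setup_test(self):
        pass

    def teardown_test(self):
        pass


@contextmanager
def run_test(case, events):
    # Plugin-controlled ordering: setup_test first, then the test body
    # (the `with` block, where function-scoped fixtures such as
    # `connection` live), then teardown_test even if the body raised.
    events.append("setup_test")
    case.setup_test()
    try:
        yield case
    finally:
        case.teardown_test()
        events.append("teardown_test")
```

Because the plugin, not pytest, owns this sequence, fixture setup and teardown can be interleaved inside it, which is what lets `connection` and `metadata` work alongside per-test setup.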
 
@@ -1788,55 +1755,51 @@ class IntervalReflectionTest(fixtures.TestBase):
     __only_on__ = "postgresql"
     __backend__ = True
 
-    def test_interval_types(self):
-        for sym in [
-            "YEAR",
-            "MONTH",
-            "DAY",
-            "HOUR",
-            "MINUTE",
-            "SECOND",
-            "YEAR TO MONTH",
-            "DAY TO HOUR",
-            "DAY TO MINUTE",
-            "DAY TO SECOND",
-            "HOUR TO MINUTE",
-            "HOUR TO SECOND",
-            "MINUTE TO SECOND",
-        ]:
-            self._test_interval_symbol(sym)
-
-    @testing.provide_metadata
-    def _test_interval_symbol(self, sym):
+    @testing.combinations(
+        ("YEAR",),
+        ("MONTH",),
+        ("DAY",),
+        ("HOUR",),
+        ("MINUTE",),
+        ("SECOND",),
+        ("YEAR TO MONTH",),
+        ("DAY TO HOUR",),
+        ("DAY TO MINUTE",),
+        ("DAY TO SECOND",),
+        ("HOUR TO MINUTE",),
+        ("HOUR TO SECOND",),
+        ("MINUTE TO SECOND",),
+        argnames="sym",
+    )
+    def test_interval_types(self, sym, metadata, connection):
         t = Table(
             "i_test",
-            self.metadata,
+            metadata,
             Column("id", Integer, primary_key=True),
             Column("data1", INTERVAL(fields=sym)),
         )
-        t.create(testing.db)
+        t.create(connection)
 
         columns = {
             rec["name"]: rec
-            for rec in inspect(testing.db).get_columns("i_test")
+            for rec in inspect(connection).get_columns("i_test")
         }
         assert isinstance(columns["data1"]["type"], INTERVAL)
         eq_(columns["data1"]["type"].fields, sym.lower())
         eq_(columns["data1"]["type"].precision, None)
 
-    @testing.provide_metadata
-    def test_interval_precision(self):
+    def test_interval_precision(self, metadata, connection):
         t = Table(
             "i_test",
-            self.metadata,
+            metadata,
             Column("id", Integer, primary_key=True),
             Column("data1", INTERVAL(precision=6)),
         )
-        t.create(testing.db)
+        t.create(connection)
 
         columns = {
             rec["name"]: rec
-            for rec in inspect(testing.db).get_columns("i_test")
+            for rec in inspect(connection).get_columns("i_test")
         }
         assert isinstance(columns["data1"]["type"], INTERVAL)
         eq_(columns["data1"]["type"].fields, None)
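The hunk above replaces a hand-rolled loop over interval symbols with `@testing.combinations(..., argnames="sym")`, which expands one test into one case per tuple while leaving the remaining parameters free for fixture injection. A toy sketch of that expansion idea (SQLAlchemy's real decorator integrates with pytest's parametrization rather than looping like this):

```python
def combinations(*argsets, argnames):
    """Toy sketch of the idea behind testing.combinations: bind each
    argument tuple to the names in `argnames`, leaving any remaining
    keyword arguments (the fixtures) to be supplied by the caller."""
    names = [n.strip() for n in argnames.split(",")]

    def decorate(fn):
        def run_all(**fixtures):
            results = []
            for argset in argsets:
                # one invocation per tuple, fixtures merged in each time
                kwargs = dict(zip(names, argset))
                kwargs.update(fixtures)
                results.append(fn(**kwargs))
            return results

        return run_all

    return decorate
```

This is why `test_interval_types(self, sym, metadata, connection)` can take `sym` from the combinations and `metadata`/`connection` from fixtures in the same signature.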
@@ -1871,8 +1834,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
             Column("id4", SmallInteger, Identity()),
         )
 
-    def test_reflect_identity(self):
-        insp = inspect(testing.db)
+    def test_reflect_identity(self, connection):
+        insp = inspect(connection)
         default = dict(
             always=False,
             start=1,
index e8a1876c7a3852e933389cba2cf50d6d69ff8331..6202f8f868b71018b4324fcc2d3ce099a8204993 100644
@@ -49,7 +49,6 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
-from sqlalchemy.testing import engines
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
@@ -156,8 +155,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
 
     __only_on__ = "postgresql > 8.3"
 
-    @testing.provide_metadata
-    def test_create_table(self, connection):
+    def test_create_table(self, metadata, connection):
         metadata = self.metadata
         t1 = Table(
             "table",
@@ -177,8 +175,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             [(1, "two"), (2, "three"), (3, "three")],
         )
 
-    @testing.combinations(None, "foo")
-    def test_create_table_schema_translate_map(self, symbol_name):
+    @testing.combinations(None, "foo", argnames="symbol_name")
+    def test_create_table_schema_translate_map(self, connection, symbol_name):
         # note we can't use the fixture here because it will not drop
         # from the correct schema
         metadata = MetaData()
@@ -199,35 +197,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             schema=symbol_name,
         )
-        with testing.db.begin() as conn:
-            conn = conn.execution_options(
-                schema_translate_map={symbol_name: testing.config.test_schema}
-            )
-            t1.create(conn)
-            assert "schema_enum" in [
-                e["name"]
-                for e in inspect(conn).get_enums(
-                    schema=testing.config.test_schema
-                )
-            ]
-            t1.create(conn, checkfirst=True)
+        conn = connection.execution_options(
+            schema_translate_map={symbol_name: testing.config.test_schema}
+        )
+        t1.create(conn)
+        assert "schema_enum" in [
+            e["name"]
+            for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+        ]
+        t1.create(conn, checkfirst=True)
 
-            conn.execute(t1.insert(), value="two")
-            conn.execute(t1.insert(), value="three")
-            conn.execute(t1.insert(), value="three")
-            eq_(
-                conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
-                [(1, "two"), (2, "three"), (3, "three")],
-            )
+        conn.execute(t1.insert(), value="two")
+        conn.execute(t1.insert(), value="three")
+        conn.execute(t1.insert(), value="three")
+        eq_(
+            conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
+            [(1, "two"), (2, "three"), (3, "three")],
+        )
 
-            t1.drop(conn)
-            assert "schema_enum" not in [
-                e["name"]
-                for e in inspect(conn).get_enums(
-                    schema=testing.config.test_schema
-                )
-            ]
-            t1.drop(conn, checkfirst=True)
+        t1.drop(conn)
+        assert "schema_enum" not in [
+            e["name"]
+            for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+        ]
+        t1.drop(conn, checkfirst=True)
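The rewrite above drops the `testing.db.begin()` block in favor of the transaction-scoped `connection` fixture, calling `execution_options(schema_translate_map=...)` on it directly. A sketch of what that option does, using a hypothetical `FakeConnection` (not the real `Connection` API) to show the name remapping:

```python
class FakeConnection:
    """Hypothetical stand-in illustrating execution_options(
    schema_translate_map=...): schema names are remapped at execution
    time, so the same Table definition can target different schemas."""

    def __init__(self, options=None):
        self._options = dict(options or {})

    def execution_options(self, **opts):
        # returns a new connection facade carrying the merged options,
        # leaving the original connection untouched
        merged = dict(self._options)
        merged.update(opts)
        return FakeConnection(merged)

    def resolve_schema(self, schema):
        mapping = self._options.get("schema_translate_map") or {}
        return mapping.get(schema, schema)
```

Mapping `None` covers the schemaless case, which is why the test parametrizes `symbol_name` over both `None` and `"foo"`.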
 
     def test_name_required(self, metadata, connection):
         etype = Enum("four", "five", "six", metadata=metadata)
@@ -270,8 +263,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             [util.u("réveillé"), util.u("drôle"), util.u("S’il")],
         )
 
-    @testing.provide_metadata
-    def test_non_native_enum(self, connection):
+    def test_non_native_enum(self, metadata, connection):
         metadata = self.metadata
         t1 = Table(
             "foo",
@@ -290,10 +282,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         )
 
         def go():
-            t1.create(testing.db)
+            t1.create(connection)
 
         self.assert_sql(
-            testing.db,
+            connection,
             go,
             [
                 (
@@ -307,8 +299,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         connection.execute(t1.insert(), {"bar": "two"})
         eq_(connection.scalar(select(t1.c.bar)), "two")
 
-    @testing.provide_metadata
-    def test_non_native_enum_w_unicode(self, connection):
+    def test_non_native_enum_w_unicode(self, metadata, connection):
         metadata = self.metadata
         t1 = Table(
             "foo",
@@ -326,10 +317,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         )
 
         def go():
-            t1.create(testing.db)
+            t1.create(connection)
 
         self.assert_sql(
-            testing.db,
+            connection,
             go,
             [
                 (
@@ -346,8 +337,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         connection.execute(t1.insert(), {"bar": util.u("Ü")})
         eq_(connection.scalar(select(t1.c.bar)), util.u("Ü"))
 
-    @testing.provide_metadata
-    def test_disable_create(self):
+    def test_disable_create(self, metadata, connection):
         metadata = self.metadata
 
         e1 = postgresql.ENUM(
@@ -357,13 +347,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         t1 = Table("e1", metadata, Column("c1", e1))
         # table can be created separately
         # without conflict
-        e1.create(bind=testing.db)
-        t1.create(testing.db)
-        t1.drop(testing.db)
-        e1.drop(bind=testing.db)
+        e1.create(bind=connection)
+        t1.create(connection)
+        t1.drop(connection)
+        e1.drop(bind=connection)
 
-    @testing.provide_metadata
-    def test_dont_keep_checking(self, connection):
+    def test_dont_keep_checking(self, metadata, connection):
         metadata = self.metadata
 
         e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -560,11 +549,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             e["name"] for e in inspect(connection).get_enums()
         ]
 
-    def test_non_native_dialect(self):
-        engine = engines.testing_engine()
+    def test_non_native_dialect(self, metadata, testing_engine):
+        engine = testing_engine()
         engine.connect()
         engine.dialect.supports_native_enum = False
-        metadata = MetaData()
         t1 = Table(
             "foo",
             metadata,
@@ -583,21 +571,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         def go():
             t1.create(engine)
 
-        try:
-            self.assert_sql(
-                engine,
-                go,
-                [
-                    (
-                        "CREATE TABLE foo (bar "
-                        "VARCHAR(5), CONSTRAINT myenum CHECK "
-                        "(bar IN ('one', 'two', 'three')))",
-                        {},
-                    )
-                ],
-            )
-        finally:
-            metadata.drop_all(engine)
+        self.assert_sql(
+            engine,
+            go,
+            [
+                (
+                    "CREATE TABLE foo (bar "
+                    "VARCHAR(5), CONSTRAINT myenum CHECK "
+                    "(bar IN ('one', 'two', 'three')))",
+                    {},
+                )
+            ],
+        )
 
     def test_standalone_enum(self, connection, metadata):
         etype = Enum(
@@ -605,26 +590,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         )
         etype.create(connection)
         try:
-            assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+            assert connection.dialect.has_type(connection, "fourfivesixtype")
         finally:
             etype.drop(connection)
-            assert not testing.db.dialect.has_type(
+            assert not connection.dialect.has_type(
                 connection, "fourfivesixtype"
             )
         metadata.create_all(connection)
         try:
-            assert testing.db.dialect.has_type(connection, "fourfivesixtype")
+            assert connection.dialect.has_type(connection, "fourfivesixtype")
         finally:
             metadata.drop_all(connection)
-            assert not testing.db.dialect.has_type(
+            assert not connection.dialect.has_type(
                 connection, "fourfivesixtype"
             )
 
-    def test_no_support(self):
+    def test_no_support(self, testing_engine):
         def server_version_info(self):
             return (8, 2)
 
-        e = engines.testing_engine()
+        e = testing_engine()
         dialect = e.dialect
         dialect._get_server_version_info = server_version_info
 
@@ -692,8 +677,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(t2.c.value2.type.name, "fourfivesixtype")
         eq_(t2.c.value2.type.schema, "test_schema")
 
-    @testing.provide_metadata
-    def test_custom_subclass(self, connection):
+    def test_custom_subclass(self, metadata, connection):
         class MyEnum(TypeDecorator):
             impl = Enum("oneHI", "twoHI", "threeHI", name="myenum")
 
@@ -708,13 +692,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
                 return value
 
         t1 = Table("table1", self.metadata, Column("data", MyEnum()))
-        self.metadata.create_all(testing.db)
+        self.metadata.create_all(connection)
 
         connection.execute(t1.insert(), {"data": "two"})
         eq_(connection.scalar(select(t1.c.data)), "twoHITHERE")
 
-    @testing.provide_metadata
-    def test_generic_w_pg_variant(self, connection):
+    def test_generic_w_pg_variant(self, metadata, connection):
         some_table = Table(
             "some_table",
             self.metadata,
@@ -752,8 +735,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             e["name"] for e in inspect(connection).get_enums()
         ]
 
-    @testing.provide_metadata
-    def test_generic_w_some_other_variant(self, connection):
+    def test_generic_w_some_other_variant(self, metadata, connection):
         some_table = Table(
             "some_table",
             self.metadata,
@@ -809,26 +791,28 @@ class RegClassTest(fixtures.TestBase):
     __only_on__ = "postgresql"
     __backend__ = True
 
-    @staticmethod
-    def _scalar(expression):
-        with testing.db.connect() as conn:
-            return conn.scalar(select(expression))
+    @testing.fixture()
+    def scalar(self, connection):
+        def go(expression):
+            return connection.scalar(select(expression))
 
-    def test_cast_name(self):
-        eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+        return go
 
-    def test_cast_path(self):
+    def test_cast_name(self, scalar):
+        eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class")
+
+    def test_cast_path(self, scalar):
         eq_(
-            self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
+            scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)),
             "pg_class",
         )
 
-    def test_cast_oid(self):
+    def test_cast_oid(self, scalar):
         regclass = cast("pg_class", postgresql.REGCLASS)
-        oid = self._scalar(cast(regclass, postgresql.OID))
+        oid = scalar(cast(regclass, postgresql.OID))
         assert isinstance(oid, int)
         eq_(
-            self._scalar(
+            scalar(
                 cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS)
             ),
             "pg_class",
@@ -1339,13 +1323,12 @@ class ArrayRoundTripTest(object):
             Column("dimarr", ProcValue),
         )
 
-    def _fixture_456(self, table):
-        with testing.db.begin() as conn:
-            conn.execute(table.insert(), intarr=[4, 5, 6])
+    def _fixture_456(self, table, connection):
+        connection.execute(table.insert(), intarr=[4, 5, 6])
 
-    def test_reflect_array_column(self):
+    def test_reflect_array_column(self, connection):
         metadata2 = MetaData()
-        tbl = Table("arrtable", metadata2, autoload_with=testing.db)
+        tbl = Table("arrtable", metadata2, autoload_with=connection)
         assert isinstance(tbl.c.intarr.type, self.ARRAY)
         assert isinstance(tbl.c.strarr.type, self.ARRAY)
         assert isinstance(tbl.c.intarr.type.item_type, Integer)
@@ -1564,7 +1547,7 @@ class ArrayRoundTripTest(object):
 
     def test_array_getitem_single_exec(self, connection):
         arrtable = self.tables.arrtable
-        self._fixture_456(arrtable)
+        self._fixture_456(arrtable, connection)
         eq_(connection.scalar(select(arrtable.c.intarr[2])), 5)
         connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7}))
         eq_(connection.scalar(select(arrtable.c.intarr[2])), 7)
@@ -1654,11 +1637,10 @@ class ArrayRoundTripTest(object):
             set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]),
         )
 
-    def test_array_plus_native_enum_create(self):
-        m = MetaData()
+    def test_array_plus_native_enum_create(self, metadata, connection):
         t = Table(
             "t",
-            m,
+            metadata,
             Column(
                 "data_1",
                 self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")),
@@ -1669,13 +1651,13 @@ class ArrayRoundTripTest(object):
             ),
         )
 
-        t.create(testing.db)
+        t.create(connection)
         eq_(
-            set(e["name"] for e in inspect(testing.db).get_enums()),
+            set(e["name"] for e in inspect(connection).get_enums()),
             set(["my_enum_1", "my_enum_2"]),
         )
-        t.drop(testing.db)
-        eq_(inspect(testing.db).get_enums(), [])
+        t.drop(connection)
+        eq_(inspect(connection).get_enums(), [])
 
 
 class CoreArrayRoundTripTest(
@@ -1690,33 +1672,35 @@ class PGArrayRoundTripTest(
 ):
     ARRAY = postgresql.ARRAY
 
-    @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
-    def test_undim_array_contains_typed_exec(self, struct):
+    @testing.combinations(
+        (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+    )
+    def test_undim_array_contains_typed_exec(self, struct, connection):
         arrtable = self.tables.arrtable
-        self._fixture_456(arrtable)
-        with testing.db.begin() as conn:
-            eq_(
-                conn.scalar(
-                    select(arrtable.c.intarr).where(
-                        arrtable.c.intarr.contains(struct([4, 5]))
-                    )
-                ),
-                [4, 5, 6],
-            )
+        self._fixture_456(arrtable, connection)
+        eq_(
+            connection.scalar(
+                select(arrtable.c.intarr).where(
+                    arrtable.c.intarr.contains(struct([4, 5]))
+                )
+            ),
+            [4, 5, 6],
+        )
 
-    @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),))
-    def test_dim_array_contains_typed_exec(self, struct):
+    @testing.combinations(
+        (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct"
+    )
+    def test_dim_array_contains_typed_exec(self, struct, connection):
         dim_arrtable = self.tables.dim_arrtable
-        self._fixture_456(dim_arrtable)
-        with testing.db.begin() as conn:
-            eq_(
-                conn.scalar(
-                    select(dim_arrtable.c.intarr).where(
-                        dim_arrtable.c.intarr.contains(struct([4, 5]))
-                    )
-                ),
-                [4, 5, 6],
-            )
+        self._fixture_456(dim_arrtable, connection)
+        eq_(
+            connection.scalar(
+                select(dim_arrtable.c.intarr).where(
+                    dim_arrtable.c.intarr.contains(struct([4, 5]))
+                )
+            ),
+            [4, 5, 6],
+        )
 
     def test_array_contained_by_exec(self, connection):
         arrtable = self.tables.arrtable
@@ -1730,7 +1714,7 @@ class PGArrayRoundTripTest(
 
     def test_undim_array_empty(self, connection):
         arrtable = self.tables.arrtable
-        self._fixture_456(arrtable)
+        self._fixture_456(arrtable, connection)
         eq_(
             connection.scalar(
                 select(arrtable.c.intarr).where(arrtable.c.intarr.contains([]))
@@ -1782,8 +1766,9 @@ class ArrayEnum(fixtures.TestBase):
         sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls"
     )
     @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
-    @testing.provide_metadata
-    def test_raises_non_native_enums(self, array_cls, enum_cls):
+    def test_raises_non_native_enums(
+        self, metadata, connection, array_cls, enum_cls
+    ):
         Table(
             "my_table",
             self.metadata,
@@ -1808,7 +1793,7 @@ class ArrayEnum(fixtures.TestBase):
             "for ARRAY of non-native ENUM; please specify "
             "create_constraint=False on this Enum datatype.",
             self.metadata.create_all,
-            testing.db,
+            connection,
         )
 
     @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls")
@@ -1818,8 +1803,7 @@ class ArrayEnum(fixtures.TestBase):
         (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
         argnames="array_cls",
     )
-    @testing.provide_metadata
-    def test_array_of_enums(self, array_cls, enum_cls, connection):
+    def test_array_of_enums(self, array_cls, enum_cls, metadata, connection):
         tbl = Table(
             "enum_table",
             self.metadata,
@@ -1875,8 +1859,7 @@ class ArrayJSON(fixtures.TestBase):
     @testing.combinations(
         sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls"
     )
-    @testing.provide_metadata
-    def test_array_of_json(self, array_cls, json_cls, connection):
+    def test_array_of_json(self, array_cls, json_cls, metadata, connection):
         tbl = Table(
             "json_table",
             self.metadata,
@@ -1982,19 +1965,38 @@ class HashableFlagORMTest(fixtures.TestBase):
                 },
             ],
         ),
+        (
+            "HSTORE",
+            postgresql.HSTORE(),
+            [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
+            testing.requires.hstore,
+        ),
+        (
+            "JSONB",
+            postgresql.JSONB(),
+            [
+                {"a": "1", "b": "2", "c": "3"},
+                {
+                    "d": "4",
+                    "e": {"e1": "5", "e2": "6"},
+                    "f": {"f1": [9, 10, 11]},
+                },
+            ],
+            testing.requires.postgresql_jsonb,
+        ),
+        argnames="type_,data",
         id_="iaa",
     )
-    @testing.provide_metadata
-    def test_hashable_flag(self, type_, data):
-        Base = declarative_base(metadata=self.metadata)
+    def test_hashable_flag(self, metadata, connection, type_, data):
+        Base = declarative_base(metadata=metadata)
 
         class A(Base):
             __tablename__ = "a1"
             id = Column(Integer, primary_key=True)
             data = Column(type_)
 
-        Base.metadata.create_all(testing.db)
-        s = Session(testing.db)
+        Base.metadata.create_all(connection)
+        s = Session(connection)
         s.add_all([A(data=elem) for elem in data])
         s.commit()
 
@@ -2006,27 +2008,6 @@ class HashableFlagORMTest(fixtures.TestBase):
             list(enumerate(data, 1)),
         )
 
-    @testing.requires.hstore
-    def test_hstore(self):
-        self.test_hashable_flag(
-            postgresql.HSTORE(),
-            [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}],
-        )
-
-    @testing.requires.postgresql_jsonb
-    def test_jsonb(self):
-        self.test_hashable_flag(
-            postgresql.JSONB(),
-            [
-                {"a": "1", "b": "2", "c": "3"},
-                {
-                    "d": "4",
-                    "e": {"e1": "5", "e2": "6"},
-                    "f": {"f1": [9, 10, 11]},
-                },
-            ],
-        )
-
 
 class TimestampTest(fixtures.TestBase, AssertsExecutionResults):
     __only_on__ = "postgresql"
@@ -2108,14 +2089,14 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
 
         return table
 
-    def test_reflection(self, special_types_table):
+    def test_reflection(self, special_types_table, connection):
         # cheat so that the "strict type check"
         # works
         special_types_table.c.year_interval.type = postgresql.INTERVAL()
         special_types_table.c.month_interval.type = postgresql.INTERVAL()
 
         m = MetaData()
-        t = Table("sometable", m, autoload_with=testing.db)
+        t = Table("sometable", m, autoload_with=connection)
 
         self.assert_tables_equal(special_types_table, t, strict_types=True)
         assert t.c.plain_interval.type.precision is None
@@ -2210,7 +2191,7 @@ class UUIDTest(fixtures.TestBase):
 class HStoreTest(AssertsCompiledSQL, fixtures.TestBase):
     __dialect__ = "postgresql"
 
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
         self.test_table = Table(
             "test_table",
@@ -2494,17 +2475,16 @@ class HStoreRoundTripTest(fixtures.TablesTest):
             Column("data", HSTORE),
         )
 
-    def _fixture_data(self, engine):
+    def _fixture_data(self, connection):
         data_table = self.tables.data_table
-        with engine.begin() as conn:
-            conn.execute(
-                data_table.insert(),
-                {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
-                {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
-                {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
-                {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
-                {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
-            )
+        connection.execute(
+            data_table.insert(),
+            {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+            {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
+            {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}},
+            {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}},
+            {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}},
+        )
 
     def _assert_data(self, compare, conn):
         data = conn.execute(
@@ -2514,26 +2494,32 @@ class HStoreRoundTripTest(fixtures.TablesTest):
         ).fetchall()
         eq_([d for d, in data], compare)
 
-    def _test_insert(self, engine):
-        with engine.begin() as conn:
-            conn.execute(
-                self.tables.data_table.insert(),
-                {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
-            )
-            self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+    def _test_insert(self, connection):
+        connection.execute(
+            self.tables.data_table.insert(),
+            {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
+        )
+        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
 
-    def _non_native_engine(self):
-        if testing.requires.psycopg2_native_hstore.enabled:
-            engine = engines.testing_engine(
-                options=dict(use_native_hstore=False)
-            )
+    @testing.fixture
+    def non_native_hstore_connection(self, testing_engine):
+        local_engine = testing.requires.psycopg2_native_hstore.enabled
+
+        if local_engine:
+            engine = testing_engine(options=dict(use_native_hstore=False))
         else:
             engine = testing.db
-        engine.connect().close()
-        return engine
 
-    def test_reflect(self):
-        insp = inspect(testing.db)
+        conn = engine.connect()
+        trans = conn.begin()
+        yield conn
+        try:
+            trans.rollback()
+        finally:
+            conn.close()
+
+    def test_reflect(self, connection):
+        insp = inspect(connection)
         cols = insp.get_columns("data_table")
         assert isinstance(cols[2]["type"], HSTORE)
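The `non_native_hstore_connection` fixture above yields a live connection and then rolls back and closes unconditionally. Its teardown shape, sketched as a context manager with hypothetical `FakeEngine`/`FakeConnection`/`FakeTransaction` stand-ins (the real fixture builds on `testing_engine`):

```python
from contextlib import contextmanager

events = []


class FakeTransaction:
    def rollback(self):
        events.append("rollback")


class FakeConnection:
    def begin(self):
        return FakeTransaction()

    def close(self):
        events.append("close")


class FakeEngine:
    def connect(self):
        return FakeConnection()


@contextmanager
def non_native_hstore_connection(engine):
    conn = engine.connect()
    trans = conn.begin()
    try:
        yield conn
    finally:
        # mirror the fixture's teardown: roll back first so table drops
        # aren't blocked by held locks, then close even if rollback fails
        try:
            trans.rollback()
        finally:
            conn.close()
```

This rollback-then-close discipline is what lets the commit drop the old "dispose the pool defensively" teardown and reuse connections across tests.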
 
@@ -2548,106 +2534,88 @@ class HStoreRoundTripTest(fixtures.TablesTest):
         eq_(connection.scalar(select(expr)), "3")
 
     @testing.requires.psycopg2_native_hstore
-    def test_insert_native(self):
-        engine = testing.db
-        self._test_insert(engine)
+    def test_insert_native(self, connection):
+        self._test_insert(connection)
 
-    def test_insert_python(self):
-        engine = self._non_native_engine()
-        self._test_insert(engine)
+    def test_insert_python(self, non_native_hstore_connection):
+        self._test_insert(non_native_hstore_connection)
 
     @testing.requires.psycopg2_native_hstore
-    def test_criterion_native(self):
-        engine = testing.db
-        self._fixture_data(engine)
-        self._test_criterion(engine)
+    def test_criterion_native(self, connection):
+        self._fixture_data(connection)
+        self._test_criterion(connection)
 
-    def test_criterion_python(self):
-        engine = self._non_native_engine()
-        self._fixture_data(engine)
-        self._test_criterion(engine)
+    def test_criterion_python(self, non_native_hstore_connection):
+        self._fixture_data(non_native_hstore_connection)
+        self._test_criterion(non_native_hstore_connection)
 
-    def _test_criterion(self, engine):
+    def _test_criterion(self, connection):
         data_table = self.tables.data_table
-        with engine.begin() as conn:
-            result = conn.execute(
-                select(data_table.c.data).where(
-                    data_table.c.data["k1"] == "r3v1"
-                )
-            ).first()
-            eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+        result = connection.execute(
+            select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1")
+        ).first()
+        eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
 
-    def _test_fixed_round_trip(self, engine):
-        with engine.begin() as conn:
-            s = select(
-                hstore(
-                    array(["key1", "key2", "key3"]),
-                    array(["value1", "value2", "value3"]),
-                )
-            )
-            eq_(
-                conn.scalar(s),
-                {"key1": "value1", "key2": "value2", "key3": "value3"},
+    def _test_fixed_round_trip(self, connection):
+        s = select(
+            hstore(
+                array(["key1", "key2", "key3"]),
+                array(["value1", "value2", "value3"]),
             )
+        )
+        eq_(
+            connection.scalar(s),
+            {"key1": "value1", "key2": "value2", "key3": "value3"},
+        )
 
-    def test_fixed_round_trip_python(self):
-        engine = self._non_native_engine()
-        self._test_fixed_round_trip(engine)
+    def test_fixed_round_trip_python(self, non_native_hstore_connection):
+        self._test_fixed_round_trip(non_native_hstore_connection)
 
     @testing.requires.psycopg2_native_hstore
-    def test_fixed_round_trip_native(self):
-        engine = testing.db
-        self._test_fixed_round_trip(engine)
+    def test_fixed_round_trip_native(self, connection):
+        self._test_fixed_round_trip(connection)
 
-    def _test_unicode_round_trip(self, engine):
-        with engine.begin() as conn:
-            s = select(
-                hstore(
-                    array(
-                        [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
-                    ),
-                    array(
-                        [util.u("réveillé"), util.u("drôle"), util.u("S’il")]
-                    ),
-                )
-            )
-            eq_(
-                conn.scalar(s),
-                {
-                    util.u("réveillé"): util.u("réveillé"),
-                    util.u("drôle"): util.u("drôle"),
-                    util.u("S’il"): util.u("S’il"),
-                },
+    def _test_unicode_round_trip(self, connection):
+        s = select(
+            hstore(
+                array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
+                array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]),
             )
+        )
+        eq_(
+            connection.scalar(s),
+            {
+                util.u("réveillé"): util.u("réveillé"),
+                util.u("drôle"): util.u("drôle"),
+                util.u("S’il"): util.u("S’il"),
+            },
+        )
 
     @testing.requires.psycopg2_native_hstore
-    def test_unicode_round_trip_python(self):
-        engine = self._non_native_engine()
-        self._test_unicode_round_trip(engine)
+    def test_unicode_round_trip_python(self, non_native_hstore_connection):
+        self._test_unicode_round_trip(non_native_hstore_connection)
 
     @testing.requires.psycopg2_native_hstore
-    def test_unicode_round_trip_native(self):
-        engine = testing.db
-        self._test_unicode_round_trip(engine)
+    def test_unicode_round_trip_native(self, connection):
+        self._test_unicode_round_trip(connection)
 
-    def test_escaped_quotes_round_trip_python(self):
-        engine = self._non_native_engine()
-        self._test_escaped_quotes_round_trip(engine)
+    def test_escaped_quotes_round_trip_python(
+        self, non_native_hstore_connection
+    ):
+        self._test_escaped_quotes_round_trip(non_native_hstore_connection)
 
     @testing.requires.psycopg2_native_hstore
-    def test_escaped_quotes_round_trip_native(self):
-        engine = testing.db
-        self._test_escaped_quotes_round_trip(engine)
+    def test_escaped_quotes_round_trip_native(self, connection):
+        self._test_escaped_quotes_round_trip(connection)
 
-    def _test_escaped_quotes_round_trip(self, engine):
-        with engine.begin() as conn:
-            conn.execute(
-                self.tables.data_table.insert(),
-                {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
-            )
-            self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn)
+    def _test_escaped_quotes_round_trip(self, connection):
+        connection.execute(
+            self.tables.data_table.insert(),
+            {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}},
+        )
+        self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection)
 
-    def test_orm_round_trip(self):
+    def test_orm_round_trip(self, connection):
         from sqlalchemy import orm
 
         class Data(object):
@@ -2656,13 +2624,14 @@ class HStoreRoundTripTest(fixtures.TablesTest):
                 self.data = data
 
         orm.mapper(Data, self.tables.data_table)
-        s = orm.Session(testing.db)
-        d = Data(
-            name="r1",
-            data={"key1": "value1", "key2": "value2", "key3": "value3"},
-        )
-        s.add(d)
-        eq_(s.query(Data.data, Data).all(), [(d.data, d)])
+
+        with orm.Session(connection) as s:
+            d = Data(
+                name="r1",
+                data={"key1": "value1", "key2": "value2", "key3": "value3"},
+            )
+            s.add(d)
+            eq_(s.query(Data.data, Data).all(), [(d.data, d)])
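The hunk above switches `orm.Session(testing.db)` to `orm.Session(connection)` used as a context manager, so the session is always closed and its work joins the test's connection-level transaction. A stdlib-only sketch of that shape (FakeSession is hypothetical, not SQLAlchemy's Session):

```python
# Toy stand-in for a Session bound to an outer connection and used as a
# context manager: exiting the block always closes it, even on error.
class FakeSession:
    def __init__(self, bind):
        self.bind = bind        # the outer connection the session joins
        self.closed = False
        self.pending = []

    def add(self, obj):
        self.pending.append(obj)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        self.closed = True
```

The point of the pattern: because the session binds to the test's connection rather than the engine, everything it does is rolled back with the test transaction, and the `with` block guarantees cleanup.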
 
 
 class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
@@ -2671,7 +2640,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
     # operator tests
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         table = Table(
             "data_table",
             MetaData(),
@@ -2852,10 +2821,10 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
     def test_actual_type(self):
         eq_(str(self._col_type()), self._col_str)
 
-    def test_reflect(self):
+    def test_reflect(self, connection):
         from sqlalchemy import inspect
 
-        insp = inspect(testing.db)
+        insp = inspect(connection)
         cols = insp.get_columns("data_table")
         assert isinstance(cols[0]["type"], self._col_type)
 
@@ -2986,8 +2955,8 @@ class _DateTimeTZRangeTests(object):
 
     def tstzs(self):
         if self._tstzs is None:
-            with testing.db.begin() as conn:
-                lower = conn.scalar(func.current_timestamp().select())
+            with testing.db.connect() as connection:
+                lower = connection.scalar(func.current_timestamp().select())
                 upper = lower + datetime.timedelta(1)
                 self._tstzs = (lower, upper)
         return self._tstzs
@@ -3052,7 +3021,7 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip):
 class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
     __dialect__ = "postgresql"
 
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
         self.test_table = Table(
             "test_table",
@@ -3151,7 +3120,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
             Column("nulldata", cls.data_type(none_as_null=True)),
         )
 
-    def _fixture_data(self, engine):
+    def _fixture_data(self, connection):
         data_table = self.tables.data_table
 
         data = [
@@ -3162,8 +3131,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
             {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}},
             {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}},
         ]
-        with engine.begin() as conn:
-            conn.execute(data_table.insert(), data)
+        connection.execute(data_table.insert(), data)
         return data
 
     def _assert_data(self, compare, conn, column="data"):
@@ -3185,51 +3153,39 @@ class JSONRoundTripTest(fixtures.TablesTest):
         ).fetchall()
         eq_([d for d, in data], [None])
 
-    def _test_insert(self, conn):
-        conn.execute(
+    def test_reflect(self, connection):
+        insp = inspect(connection)
+        cols = insp.get_columns("data_table")
+        assert isinstance(cols[2]["type"], self.data_type)
+
+    def test_insert(self, connection):
+        connection.execute(
             self.tables.data_table.insert(),
             {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
         )
-        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn)
+        self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection)
 
-    def _test_insert_nulls(self, conn):
-        conn.execute(
+    def test_insert_nulls(self, connection):
+        connection.execute(
             self.tables.data_table.insert(), {"name": "r1", "data": null()}
         )
-        self._assert_data([None], conn)
+        self._assert_data([None], connection)
 
-    def _test_insert_none_as_null(self, conn):
-        conn.execute(
+    def test_insert_none_as_null(self, connection):
+        connection.execute(
             self.tables.data_table.insert(),
             {"name": "r1", "nulldata": None},
         )
-        self._assert_column_is_NULL(conn, column="nulldata")
+        self._assert_column_is_NULL(connection, column="nulldata")
 
-    def _test_insert_nulljson_into_none_as_null(self, conn):
-        conn.execute(
+    def test_insert_nulljson_into_none_as_null(self, connection):
+        connection.execute(
             self.tables.data_table.insert(),
             {"name": "r1", "nulldata": JSON.NULL},
         )
-        self._assert_column_is_JSON_NULL(conn, column="nulldata")
-
-    def test_reflect(self):
-        insp = inspect(testing.db)
-        cols = insp.get_columns("data_table")
-        assert isinstance(cols[2]["type"], self.data_type)
-
-    def test_insert(self, connection):
-        self._test_insert(connection)
-
-    def test_insert_nulls(self, connection):
-        self._test_insert_nulls(connection)
+        self._assert_column_is_JSON_NULL(connection, column="nulldata")
 
-    def test_insert_none_as_null(self, connection):
-        self._test_insert_none_as_null(connection)
-
-    def test_insert_nulljson_into_none_as_null(self, connection):
-        self._test_insert_nulljson_into_none_as_null(connection)
-
-    def test_custom_serialize_deserialize(self):
+    def test_custom_serialize_deserialize(self, testing_engine):
         import json
 
         def loads(value):
@@ -3242,7 +3198,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
             value["x"] = "dumps_y"
             return json.dumps(value)
 
-        engine = engines.testing_engine(
+        engine = testing_engine(
             options=dict(json_serializer=dumps, json_deserializer=loads)
         )
 
@@ -3250,14 +3206,26 @@ class JSONRoundTripTest(fixtures.TablesTest):
         with engine.begin() as conn:
             eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
 
-    def test_criterion(self):
-        engine = testing.db
-        self._fixture_data(engine)
-        self._test_criterion(engine)
+    def test_criterion(self, connection):
+        self._fixture_data(connection)
+        data_table = self.tables.data_table
+
+        result = connection.execute(
+            select(data_table.c.data).where(
+                data_table.c.data["k1"].astext == "r3v1"
+            )
+        ).first()
+        eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
+
+        result = connection.execute(
+            select(data_table.c.data).where(
+                data_table.c.data["k1"].astext.cast(String) == "r3v1"
+            )
+        ).first()
+        eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
 
     def test_path_query(self, connection):
-        engine = testing.db
-        self._fixture_data(engine)
+        self._fixture_data(connection)
         data_table = self.tables.data_table
 
         result = connection.execute(
@@ -3271,8 +3239,7 @@ class JSONRoundTripTest(fixtures.TablesTest):
         "postgresql < 9.4", "Improvement in PostgreSQL behavior?"
     )
     def test_multi_index_query(self, connection):
-        engine = testing.db
-        self._fixture_data(engine)
+        self._fixture_data(connection)
         data_table = self.tables.data_table
 
         result = connection.execute(
@@ -3283,20 +3250,18 @@ class JSONRoundTripTest(fixtures.TablesTest):
         eq_(result.scalar(), "r6")
 
     def test_query_returned_as_text(self, connection):
-        engine = testing.db
-        self._fixture_data(engine)
+        self._fixture_data(connection)
         data_table = self.tables.data_table
         result = connection.execute(
             select(data_table.c.data["k1"].astext)
         ).first()
-        if engine.dialect.returns_unicode_strings:
+        if connection.dialect.returns_unicode_strings:
             assert isinstance(result[0], util.text_type)
         else:
             assert isinstance(result[0], util.string_types)
 
     def test_query_returned_as_int(self, connection):
-        engine = testing.db
-        self._fixture_data(engine)
+        self._fixture_data(connection)
         data_table = self.tables.data_table
         result = connection.execute(
             select(data_table.c.data["k3"].astext.cast(Integer)).where(
@@ -3305,23 +3270,6 @@ class JSONRoundTripTest(fixtures.TablesTest):
         ).first()
         assert isinstance(result[0], int)
 
-    def _test_criterion(self, engine):
-        data_table = self.tables.data_table
-        with engine.begin() as conn:
-            result = conn.execute(
-                select(data_table.c.data).where(
-                    data_table.c.data["k1"].astext == "r3v1"
-                )
-            ).first()
-            eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
-            result = conn.execute(
-                select(data_table.c.data).where(
-                    data_table.c.data["k1"].astext.cast(String) == "r3v1"
-                )
-            ).first()
-            eq_(result, ({"k1": "r3v1", "k2": "r3v2"},))
-
     def test_fixed_round_trip(self, connection):
         s = select(
             cast(
@@ -3352,42 +3300,41 @@ class JSONRoundTripTest(fixtures.TablesTest):
             },
         )
 
-    def test_eval_none_flag_orm(self):
+    def test_eval_none_flag_orm(self, connection):
         Base = declarative_base()
 
         class Data(Base):
             __table__ = self.tables.data_table
 
-        s = Session(testing.db)
+        with Session(connection) as s:
+            d1 = Data(name="d1", data=None, nulldata=None)
+            s.add(d1)
+            s.commit()
 
-        d1 = Data(name="d1", data=None, nulldata=None)
-        s.add(d1)
-        s.commit()
-
-        s.bulk_insert_mappings(
-            Data, [{"name": "d2", "data": None, "nulldata": None}]
-        )
-        eq_(
-            s.query(
-                cast(self.tables.data_table.c.data, String),
-                cast(self.tables.data_table.c.nulldata, String),
+            s.bulk_insert_mappings(
+                Data, [{"name": "d2", "data": None, "nulldata": None}]
             )
-            .filter(self.tables.data_table.c.name == "d1")
-            .first(),
-            ("null", None),
-        )
-        eq_(
-            s.query(
-                cast(self.tables.data_table.c.data, String),
-                cast(self.tables.data_table.c.nulldata, String),
+            eq_(
+                s.query(
+                    cast(self.tables.data_table.c.data, String),
+                    cast(self.tables.data_table.c.nulldata, String),
+                )
+                .filter(self.tables.data_table.c.name == "d1")
+                .first(),
+                ("null", None),
+            )
+            eq_(
+                s.query(
+                    cast(self.tables.data_table.c.data, String),
+                    cast(self.tables.data_table.c.nulldata, String),
+                )
+                .filter(self.tables.data_table.c.name == "d2")
+                .first(),
+                ("null", None),
             )
-            .filter(self.tables.data_table.c.name == "d2")
-            .first(),
-            ("null", None),
-        )
 
     def test_literal(self, connection):
-        exp = self._fixture_data(testing.db)
+        exp = self._fixture_data(connection)
         result = connection.exec_driver_sql(
             "select data from data_table order by name"
         )
@@ -3395,11 +3342,10 @@ class JSONRoundTripTest(fixtures.TablesTest):
         eq_(len(res), len(exp))
         for row, expected in zip(res, exp):
             eq_(row[0], expected["data"])
-        result.close()
 
 
 class JSONBTest(JSONTest):
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
         self.test_table = Table(
             "test_table",
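The file above repeatedly replaces `engine = self._non_native_engine()` / `with engine.begin() as conn:` with an injected `connection` fixture. A minimal sketch of how such a generator-style fixture works, assuming the standard pattern of yielding a transaction-scoped connection and rolling it back on teardown (FakeConnection and the runner are illustrative, not SQLAlchemy's testing plumbing):

```python
# Toy connection object that records what happened to it.
class FakeConnection:
    def __init__(self):
        self.statements = []
        self.rolled_back = False

    def execute(self, stmt):
        self.statements.append(stmt)

    def rollback(self):
        self.rolled_back = True


def connection_fixture():
    """Generator fixture: yield a connection, roll back on teardown."""
    conn = FakeConnection()
    try:
        yield conn
    finally:
        conn.rollback()


def run_test_with_fixture(test_fn):
    """Drive the fixture the way a fixture runner would."""
    gen = connection_fixture()
    conn = next(gen)            # setup: obtain the connection
    try:
        test_fn(conn)
    finally:
        try:
            next(gen)           # teardown: resumes past yield, rolls back
        except StopIteration:
            pass
    return conn
```

Because teardown rolls back rather than disposing an engine, tests never need per-test engines or manual table cleanup, which is what lets the diff delete all the `_non_native_engine()` boilerplate.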
index 4658b40a8d0884d8bd49fa7139290563ffab5de3..1926c6065282d7ca3dcbec232b7ba582124d6ffa 100644 (file)
@@ -37,6 +37,7 @@ from sqlalchemy import UniqueConstraint
 from sqlalchemy import util
 from sqlalchemy.dialects.sqlite import base as sqlite
 from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.dialects.sqlite import provision
 from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect
 from sqlalchemy.engine.url import make_url
 from sqlalchemy.schema import CreateTable
@@ -46,6 +47,7 @@ from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import combinations
+from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
@@ -774,7 +776,7 @@ class AttachedDBTest(fixtures.TestBase):
 
     def _fixture(self):
         meta = self.metadata
-        self.conn = testing.db.connect()
+        self.conn = self.engine.connect()
         Table("created", meta, Column("foo", Integer), Column("bar", String))
         Table("local_only", meta, Column("q", Integer), Column("p", Integer))
 
@@ -798,14 +800,20 @@ class AttachedDBTest(fixtures.TestBase):
             meta.create_all(self.conn)
         return ct
 
-    def setup(self):
-        self.conn = testing.db.connect()
+    def setup_test(self):
+        self.engine = engines.testing_engine(options={"use_reaper": False})
+
+        provision._sqlite_post_configure_engine(
+            self.engine.url, self.engine, config.ident
+        )
+        self.conn = self.engine.connect()
         self.metadata = MetaData()
 
-    def teardown(self):
+    def teardown_test(self):
         with self.conn.begin():
             self.metadata.drop_all(self.conn)
         self.conn.close()
+        self.engine.dispose()
 
     def test_no_tables(self):
         insp = inspect(self.conn)
@@ -1495,7 +1503,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
     __skip_if__ = (full_text_search_missing,)
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global metadata, cattable, matchtable
         metadata = MetaData()
         exec_sql(
@@ -1559,7 +1567,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
             )
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         metadata.drop_all(testing.db)
 
     def test_expression(self):
@@ -1681,7 +1689,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL):
 class ReflectHeadlessFKsTest(fixtures.TestBase):
     __only_on__ = "sqlite"
 
-    def setup(self):
+    def setup_test(self):
         exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)")
         # this syntax actually works on other DBs perhaps we'd want to add
         # tests to test_reflection
@@ -1689,7 +1697,7 @@ class ReflectHeadlessFKsTest(fixtures.TestBase):
             testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)"
         )
 
-    def teardown(self):
+    def teardown_test(self):
         exec_sql(testing.db, "drop table b")
         exec_sql(testing.db, "drop table a")
 
@@ -1728,7 +1736,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
     __only_on__ = "sqlite"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         with testing.db.begin() as conn:
 
             conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)")
@@ -1876,7 +1884,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
             )
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         with testing.db.begin() as conn:
             for name in [
                 "implicit_referrer_comp_fake",
@@ -2370,7 +2378,7 @@ class SavepointTest(fixtures.TablesTest):
 
     @classmethod
     def setup_bind(cls):
-        engine = engines.testing_engine(options={"use_reaper": False})
+        engine = engines.testing_engine(options={"scope": "class"})
 
         @event.listens_for(engine, "connect")
         def do_connect(dbapi_connection, connection_record):
@@ -2579,7 +2587,7 @@ class TypeReflectionTest(fixtures.TestBase):
 class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "sqlite"
 
-    def setUp(self):
+    def setup_test(self):
         self.table = table(
             "mytable", column("myid", Integer), column("name", String)
         )
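The renames above (`setup`/`setUp` → `setup_test`, `setup_class` → `setup_test_class`) exist because the commit stops using the xdist-style names pytest hardcodes and instead invokes the new methods from its own autouse fixtures, so that function-scoped fixtures can run inside them. A stdlib-only sketch of that wrapping (the runner below is illustrative, not SQLAlchemy's plugin_base code):

```python
def run_test(instance, test_name):
    """Run one test method wrapped in setup_test/teardown_test."""
    events = []
    setup = getattr(instance, "setup_test", None)
    teardown = getattr(instance, "teardown_test", None)
    if setup is not None:
        setup()
        events.append("setup_test")
    try:
        getattr(instance, test_name)()
        events.append(test_name)
    finally:
        # teardown always runs, in a documented, controllable order
        if teardown is not None:
            teardown()
            events.append("teardown_test")
    return events


class DemoTest:
    def setup_test(self):
        self.resource = "open"

    def teardown_test(self):
        self.resource = "closed"

    def test_something(self):
        assert self.resource == "open"
```

Since the names no longer match pytest's built-in `setup`/`teardown` conventions, only this explicit runner calls them, which is what gives the commit full control over the ordering of setup/teardown steps.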
index 396b48aa4af18b943090dc3b1238e83470a30be6..baa766d48fda4d636537163a56d475c71652439a 100644 (file)
@@ -21,7 +21,7 @@ from sqlalchemy.testing.schema import Table
 
 
 class DDLEventTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         self.bind = engines.mock_engine()
         self.metadata = MetaData()
         self.table = Table("t", self.metadata, Column("id", Integer))
@@ -374,7 +374,7 @@ class DDLEventTest(fixtures.TestBase):
 
 
 class DDLExecutionTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         self.engine = engines.mock_engine()
         self.metadata = MetaData()
         self.users = Table(
index a18cf756b1a682b5097f7897702e09d4f8a83461..0a2c9abe5842bc2759d5aedf8cda8068f3277864 100644 (file)
@@ -965,7 +965,7 @@ class TransactionTest(fixtures.TablesTest):
 class HandleInvalidatedOnConnectTest(fixtures.TestBase):
     __requires__ = ("sqlite",)
 
-    def setUp(self):
+    def setup_test(self):
         e = create_engine("sqlite://")
 
         connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -1021,18 +1021,18 @@ def MockDBAPI():  # noqa
 
 
 class PoolTestBase(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         pool.clear_managers()
         self._teardown_conns = []
 
-    def teardown(self):
+    def teardown_test(self):
         for ref in self._teardown_conns:
             conn = ref()
             if conn:
                 conn.close()
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         pool.clear_managers()
 
     def _queuepool_fixture(self, **kw):
@@ -1597,7 +1597,7 @@ class EngineEventsTest(fixtures.TestBase):
     __requires__ = ("ad_hoc_engines",)
     __backend__ = True
 
-    def tearDown(self):
+    def teardown_test(self):
         Engine.dispatch._clear()
         Engine._has_events = False
 
@@ -1650,6 +1650,7 @@ class EngineEventsTest(fixtures.TestBase):
         event.listen(
             engine, "before_cursor_execute", cursor_execute, retval=True
         )
+
         with testing.expect_deprecated(
             r"The argument signature for the "
             r"\"ConnectionEvents.before_execute\" event listener",
@@ -1676,11 +1677,12 @@ class EngineEventsTest(fixtures.TestBase):
             r"The argument signature for the "
             r"\"ConnectionEvents.after_execute\" event listener",
         ):
-            e1.execute(select(1))
+            result = e1.execute(select(1))
+            result.close()
 
 
 class DDLExecutionTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         self.engine = engines.mock_engine()
         self.metadata = MetaData()
         self.users = Table(
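The commit message's headline number (42 vs. 164 new connections) comes from no longer disposing pools defensively after every test. A toy pool, with made-up counts purely to illustrate the effect, not SQLAlchemy's QueuePool:

```python
class TinyPool:
    """Toy connection pool that counts raw connections it creates."""

    def __init__(self):
        self.created = 0
        self._idle = []

    def connect(self):
        if self._idle:
            return self._idle.pop()   # reuse a pooled connection
        self.created += 1             # otherwise open a new one
        return object()

    def release(self, conn):
        self._idle.append(conn)

    def dispose(self):
        self._idle.clear()            # drop all pooled connections


def run_suite(num_tests, dispose_each_time):
    """Simulate a test suite; return how many connections were opened."""
    pool = TinyPool()
    for _ in range(num_tests):
        conn = pool.connect()
        pool.release(conn)
        if dispose_each_time:
            pool.dispose()
    return pool.created
```

Disposing after every test forces a fresh connection per test; keeping the pool alive lets the whole suite run on one, which is the "broken and expensive connection flow" the commit removes.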
index 21d4e06e06ea38358f0345e654d463319c59f999..a1e4ea218ea43ce0433a0d1dd88ae2b0b62681cc 100644 (file)
@@ -43,7 +43,6 @@ from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import CompiledSQL
-from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing.mock import call
 from sqlalchemy.testing.mock import Mock
 from sqlalchemy.testing.mock import patch
@@ -94,13 +93,13 @@ class ExecuteTest(fixtures.TablesTest):
             ).default_from()
         )
 
-        conn = testing.db.connect()
-        result = (
-            conn.execution_options(no_parameters=True)
-            .exec_driver_sql(stmt)
-            .scalar()
-        )
-        eq_(result, "%")
+        with testing.db.connect() as conn:
+            result = (
+                conn.execution_options(no_parameters=True)
+                .exec_driver_sql(stmt)
+                .scalar()
+            )
+            eq_(result, "%")
 
     def test_raw_positional_invalid(self, connection):
         assert_raises_message(
@@ -261,16 +260,15 @@ class ExecuteTest(fixtures.TablesTest):
             (4, "sally"),
         ]
 
-    @testing.engines.close_open_connections
     def test_exception_wrapping_dbapi(self):
-        conn = testing.db.connect()
-        # engine does not have exec_driver_sql
-        assert_raises_message(
-            tsa.exc.DBAPIError,
-            r"not_a_valid_statement",
-            conn.exec_driver_sql,
-            "not_a_valid_statement",
-        )
+        with testing.db.connect() as conn:
+            # engine does not have exec_driver_sql
+            assert_raises_message(
+                tsa.exc.DBAPIError,
+                r"not_a_valid_statement",
+                conn.exec_driver_sql,
+                "not_a_valid_statement",
+            )
 
     @testing.requires.sqlite
     def test_exception_wrapping_non_dbapi_error(self):
@@ -864,12 +862,10 @@ class CompiledCacheTest(fixtures.TestBase):
         ["sqlite", "mysql", "postgresql"],
         "uses blob value that is problematic for some DBAPIs",
     )
-    @testing.provide_metadata
-    def test_cache_noleak_on_statement_values(self, connection):
+    def test_cache_noleak_on_statement_values(self, metadata, connection):
         # This is a non regression test for an object reference leak caused
         # by the compiled_cache.
 
-        metadata = self.metadata
         photo = Table(
             "photo",
             metadata,
@@ -1040,7 +1036,19 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
     __requires__ = ("schemas",)
     __backend__ = True
 
-    def test_create_table(self):
+    @testing.fixture
+    def plain_tables(self, metadata):
+        t1 = Table(
+            "t1", metadata, Column("x", Integer), schema=config.test_schema
+        )
+        t2 = Table(
+            "t2", metadata, Column("x", Integer), schema=config.test_schema
+        )
+        t3 = Table("t3", metadata, Column("x", Integer), schema=None)
+
+        return t1, t2, t3
+
+    def test_create_table(self, plain_tables, connection):
         map_ = {
             None: config.test_schema,
             "foo": config.test_schema,
@@ -1052,18 +1060,16 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
         t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
 
-        with self.sql_execution_asserter(config.db) as asserter:
-            with config.db.begin() as conn, conn.execution_options(
-                schema_translate_map=map_
-            ) as conn:
+        with self.sql_execution_asserter(connection) as asserter:
+            conn = connection.execution_options(schema_translate_map=map_)
 
-                t1.create(conn)
-                t2.create(conn)
-                t3.create(conn)
+            t1.create(conn)
+            t2.create(conn)
+            t3.create(conn)
 
-                t3.drop(conn)
-                t2.drop(conn)
-                t1.drop(conn)
+            t3.drop(conn)
+            t2.drop(conn)
+            t1.drop(conn)
 
         asserter.assert_(
             CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"),
@@ -1074,14 +1080,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
             CompiledSQL("DROP TABLE [SCHEMA__none].t1"),
         )
 
-    def _fixture(self):
-        metadata = self.metadata
-        Table("t1", metadata, Column("x", Integer), schema=config.test_schema)
-        Table("t2", metadata, Column("x", Integer), schema=config.test_schema)
-        Table("t3", metadata, Column("x", Integer), schema=None)
-        metadata.create_all(testing.db)
-
-    def test_ddl_hastable(self):
+    def test_ddl_hastable(self, plain_tables, connection):
 
         map_ = {
             None: config.test_schema,
@@ -1094,27 +1093,28 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         Table("t2", metadata, Column("x", Integer), schema="foo")
         Table("t3", metadata, Column("x", Integer), schema="bar")
 
-        with config.db.begin() as conn:
-            conn = conn.execution_options(schema_translate_map=map_)
-            metadata.create_all(conn)
+        conn = connection.execution_options(schema_translate_map=map_)
+        metadata.create_all(conn)
 
-        insp = inspect(config.db)
+        insp = inspect(connection)
         is_true(insp.has_table("t1", schema=config.test_schema))
         is_true(insp.has_table("t2", schema=config.test_schema))
         is_true(insp.has_table("t3", schema=None))
 
-        with config.db.begin() as conn:
-            conn = conn.execution_options(schema_translate_map=map_)
-            metadata.drop_all(conn)
+        conn = connection.execution_options(schema_translate_map=map_)
+
+        # if this test fails, the tables won't get dropped.  so need a
+        # more robust fixture for this
+        metadata.drop_all(conn)
 
-        insp = inspect(config.db)
+        insp = inspect(connection)
         is_false(insp.has_table("t1", schema=config.test_schema))
         is_false(insp.has_table("t2", schema=config.test_schema))
         is_false(insp.has_table("t3", schema=None))
 
-    @testing.provide_metadata
-    def test_option_on_execute(self):
-        self._fixture()
+    def test_option_on_execute(self, plain_tables, connection):
+        # provided by metadata fixture provided by plain_tables fixture
+        self.metadata.create_all(connection)
 
         map_ = {
             None: config.test_schema,
@@ -1127,61 +1127,54 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
         t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
 
-        with self.sql_execution_asserter(config.db) as asserter:
-            with config.db.begin() as conn:
+        with self.sql_execution_asserter(connection) as asserter:
+            conn = connection
+            execution_options = {"schema_translate_map": map_}
+            conn._execute_20(
+                t1.insert(), {"x": 1}, execution_options=execution_options
+            )
+            conn._execute_20(
+                t2.insert(), {"x": 1}, execution_options=execution_options
+            )
+            conn._execute_20(
+                t3.insert(), {"x": 1}, execution_options=execution_options
+            )
 
-                execution_options = {"schema_translate_map": map_}
-                conn._execute_20(
-                    t1.insert(), {"x": 1}, execution_options=execution_options
-                )
-                conn._execute_20(
-                    t2.insert(), {"x": 1}, execution_options=execution_options
-                )
-                conn._execute_20(
-                    t3.insert(), {"x": 1}, execution_options=execution_options
-                )
+            conn._execute_20(
+                t1.update().values(x=1).where(t1.c.x == 1),
+                execution_options=execution_options,
+            )
+            conn._execute_20(
+                t2.update().values(x=2).where(t2.c.x == 1),
+                execution_options=execution_options,
+            )
+            conn._execute_20(
+                t3.update().values(x=3).where(t3.c.x == 1),
+                execution_options=execution_options,
+            )
 
+            eq_(
                 conn._execute_20(
-                    t1.update().values(x=1).where(t1.c.x == 1),
-                    execution_options=execution_options,
-                )
+                    select(t1.c.x), execution_options=execution_options
+                ).scalar(),
+                1,
+            )
+            eq_(
                 conn._execute_20(
-                    t2.update().values(x=2).where(t2.c.x == 1),
-                    execution_options=execution_options,
-                )
+                    select(t2.c.x), execution_options=execution_options
+                ).scalar(),
+                2,
+            )
+            eq_(
                 conn._execute_20(
-                    t3.update().values(x=3).where(t3.c.x == 1),
-                    execution_options=execution_options,
-                )
-
-                eq_(
-                    conn._execute_20(
-                        select(t1.c.x), execution_options=execution_options
-                    ).scalar(),
-                    1,
-                )
-                eq_(
-                    conn._execute_20(
-                        select(t2.c.x), execution_options=execution_options
-                    ).scalar(),
-                    2,
-                )
-                eq_(
-                    conn._execute_20(
-                        select(t3.c.x), execution_options=execution_options
-                    ).scalar(),
-                    3,
-                )
+                    select(t3.c.x), execution_options=execution_options
+                ).scalar(),
+                3,
+            )
 
-                conn._execute_20(
-                    t1.delete(), execution_options=execution_options
-                )
-                conn._execute_20(
-                    t2.delete(), execution_options=execution_options
-                )
-                conn._execute_20(
-                    t3.delete(), execution_options=execution_options
-                )
+            conn._execute_20(t1.delete(), execution_options=execution_options)
+            conn._execute_20(t2.delete(), execution_options=execution_options)
+            conn._execute_20(t3.delete(), execution_options=execution_options)
 
         asserter.assert_(
             CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1207,9 +1200,9 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
             CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
         )
 
-    @testing.provide_metadata
-    def test_crud(self):
-        self._fixture()
+    def test_crud(self, plain_tables, connection):
+        # self.metadata is supplied by the metadata fixture, which the
+        # plain_tables fixture builds on
+        self.metadata.create_all(connection)
 
         map_ = {
             None: config.test_schema,
@@ -1222,26 +1215,24 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         t2 = Table("t2", metadata, Column("x", Integer), schema="foo")
         t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
 
-        with self.sql_execution_asserter(config.db) as asserter:
-            with config.db.begin() as conn, conn.execution_options(
-                schema_translate_map=map_
-            ) as conn:
+        with self.sql_execution_asserter(connection) as asserter:
+            conn = connection.execution_options(schema_translate_map=map_)
 
-                conn.execute(t1.insert(), {"x": 1})
-                conn.execute(t2.insert(), {"x": 1})
-                conn.execute(t3.insert(), {"x": 1})
+            conn.execute(t1.insert(), {"x": 1})
+            conn.execute(t2.insert(), {"x": 1})
+            conn.execute(t3.insert(), {"x": 1})
 
-                conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
-                conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
-                conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
+            conn.execute(t1.update().values(x=1).where(t1.c.x == 1))
+            conn.execute(t2.update().values(x=2).where(t2.c.x == 1))
+            conn.execute(t3.update().values(x=3).where(t3.c.x == 1))
 
-                eq_(conn.scalar(select(t1.c.x)), 1)
-                eq_(conn.scalar(select(t2.c.x)), 2)
-                eq_(conn.scalar(select(t3.c.x)), 3)
+            eq_(conn.scalar(select(t1.c.x)), 1)
+            eq_(conn.scalar(select(t2.c.x)), 2)
+            eq_(conn.scalar(select(t3.c.x)), 3)
 
-                conn.execute(t1.delete())
-                conn.execute(t2.delete())
-                conn.execute(t3.delete())
+            conn.execute(t1.delete())
+            conn.execute(t2.delete())
+            conn.execute(t3.delete())
 
         asserter.assert_(
             CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"),
@@ -1267,9 +1258,10 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
             CompiledSQL("DELETE FROM [SCHEMA_bar].t3"),
         )
 
-    @testing.provide_metadata
-    def test_via_engine(self):
-        self._fixture()
+    def test_via_engine(self, plain_tables, metadata):
+
+        with config.db.begin() as connection:
+            metadata.create_all(connection)
 
         map_ = {
             None: config.test_schema,
@@ -1282,25 +1274,25 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
 
         with self.sql_execution_asserter(config.db) as asserter:
             eng = config.db.execution_options(schema_translate_map=map_)
-            conn = eng.connect()
-            conn.execute(select(t2.c.x))
+            with eng.connect() as conn:
+                conn.execute(select(t2.c.x))
         asserter.assert_(
             CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2")
         )
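The schema_translate_map exercised in the tests above is, at execution time, essentially a dictionary lookup keyed on each table's schema, with None standing in for "no schema" and unmapped schemas passing through unchanged. A standalone sketch of that lookup (a hypothetical helper, not SQLAlchemy's compiler internals):

```python
def translate_schema(schema, schema_translate_map):
    # Look up the table's schema in the map; None is a valid key
    # meaning "tables with no schema". Unmapped schemas pass through.
    return schema_translate_map.get(schema, schema)

map_ = {None: "test_schema", "foo": "test_schema_2"}
assert translate_schema(None, map_) == "test_schema"
assert translate_schema("foo", map_) == "test_schema_2"
assert translate_schema("bar", map_) == "bar"  # not in the map
```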
 
 
 class ExecutionOptionsTest(fixtures.TestBase):
-    def test_dialect_conn_options(self):
+    def test_dialect_conn_options(self, testing_engine):
         engine = testing_engine("sqlite://", options=dict(_initialize=False))
         engine.dialect = Mock()
-        conn = engine.connect()
-        c2 = conn.execution_options(foo="bar")
-        eq_(
-            engine.dialect.set_connection_execution_options.mock_calls,
-            [call(c2, {"foo": "bar"})],
-        )
+        with engine.connect() as conn:
+            c2 = conn.execution_options(foo="bar")
+            eq_(
+                engine.dialect.set_connection_execution_options.mock_calls,
+                [call(c2, {"foo": "bar"})],
+            )
 
-    def test_dialect_engine_options(self):
+    def test_dialect_engine_options(self, testing_engine):
         engine = testing_engine("sqlite://")
         engine.dialect = Mock()
         e2 = engine.execution_options(foo="bar")
@@ -1319,14 +1311,14 @@ class ExecutionOptionsTest(fixtures.TestBase):
             [call(engine, {"foo": "bar"})],
         )
 
-    def test_propagate_engine_to_connection(self):
+    def test_propagate_engine_to_connection(self, testing_engine):
         engine = testing_engine(
             "sqlite://", options=dict(execution_options={"foo": "bar"})
         )
-        conn = engine.connect()
-        eq_(conn._execution_options, {"foo": "bar"})
+        with engine.connect() as conn:
+            eq_(conn._execution_options, {"foo": "bar"})
 
-    def test_propagate_option_engine_to_connection(self):
+    def test_propagate_option_engine_to_connection(self, testing_engine):
         e1 = testing_engine(
             "sqlite://", options=dict(execution_options={"foo": "bar"})
         )
@@ -1336,27 +1328,30 @@ class ExecutionOptionsTest(fixtures.TestBase):
         eq_(c1._execution_options, {"foo": "bar"})
         eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"})
 
-    def test_get_engine_execution_options(self):
+        c1.close()
+        c2.close()
+
+    def test_get_engine_execution_options(self, testing_engine):
         engine = testing_engine("sqlite://")
         engine.dialect = Mock()
         e2 = engine.execution_options(foo="bar")
 
         eq_(e2.get_execution_options(), {"foo": "bar"})
 
-    def test_get_connection_execution_options(self):
+    def test_get_connection_execution_options(self, testing_engine):
         engine = testing_engine("sqlite://", options=dict(_initialize=False))
         engine.dialect = Mock()
-        conn = engine.connect()
-        c = conn.execution_options(foo="bar")
+        with engine.connect() as conn:
+            c = conn.execution_options(foo="bar")
 
-        eq_(c.get_execution_options(), {"foo": "bar"})
+            eq_(c.get_execution_options(), {"foo": "bar"})
 
 
 class EngineEventsTest(fixtures.TestBase):
     __requires__ = ("ad_hoc_engines",)
     __backend__ = True
 
-    def tearDown(self):
+    def teardown_test(self):
         Engine.dispatch._clear()
         Engine._has_events = False
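The setup_test/teardown_test renames throughout this commit depend on the test plugin locating these methods by name and bracketing each test with them, instead of letting pytest's hardcoded xdist-style names run them. A minimal standalone sketch of that bracketing flow (hypothetical runner class, not the actual plugin_base code):

```python
from contextlib import contextmanager


class SetupTeardownRunner:
    """Sketch: find setup_test()/teardown_test() on a test instance
    and wrap the test body with them, so function-scoped fixtures can
    be made to run inside the bracket."""

    def __init__(self, test_instance):
        self.test = test_instance

    @contextmanager
    def run(self):
        setup = getattr(self.test, "setup_test", None)
        teardown = getattr(self.test, "teardown_test", None)
        if setup:
            setup()
        try:
            yield self.test
        finally:
            # teardown always runs, even if the test body raised
            if teardown:
                teardown()
```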
 
@@ -1376,7 +1371,7 @@ class EngineEventsTest(fixtures.TestBase):
                 ):
                     break
 
-    def test_per_engine_independence(self):
+    def test_per_engine_independence(self, testing_engine):
         e1 = testing_engine(config.db_url)
         e2 = testing_engine(config.db_url)
 
@@ -1400,7 +1395,7 @@ class EngineEventsTest(fixtures.TestBase):
             conn.execute(s2)
         eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2])
 
-    def test_per_engine_plus_global(self):
+    def test_per_engine_plus_global(self, testing_engine):
         canary = Mock()
         event.listen(Engine, "before_execute", canary.be1)
         e1 = testing_engine(config.db_url)
@@ -1409,8 +1404,6 @@ class EngineEventsTest(fixtures.TestBase):
         event.listen(e1, "before_execute", canary.be2)
 
         event.listen(Engine, "before_execute", canary.be3)
-        e1.connect()
-        e2.connect()
 
         with e1.connect() as conn:
             conn.execute(select(1))
@@ -1424,7 +1417,7 @@ class EngineEventsTest(fixtures.TestBase):
         eq_(canary.be2.call_count, 1)
         eq_(canary.be3.call_count, 2)
 
-    def test_per_connection_plus_engine(self):
+    def test_per_connection_plus_engine(self, testing_engine):
         canary = Mock()
         e1 = testing_engine(config.db_url)
 
@@ -1442,9 +1435,14 @@ class EngineEventsTest(fixtures.TestBase):
             eq_(canary.be1.call_count, 2)
             eq_(canary.be2.call_count, 2)
 
-    @testing.combinations((True, False), (True, True), (False, False))
+    @testing.combinations(
+        (True, False),
+        (True, True),
+        (False, False),
+        argnames="mock_out_on_connect, add_our_own_onconnect",
+    )
     def test_insert_connect_is_definitely_first(
-        self, mock_out_on_connect, add_our_own_onconnect
+        self, mock_out_on_connect, add_our_own_onconnect, testing_engine
     ):
         """test issue #5708.
 
@@ -1478,7 +1476,7 @@ class EngineEventsTest(fixtures.TestBase):
             patcher = util.nullcontext()
 
         with patcher:
-            e1 = create_engine(config.db_url)
+            e1 = testing_engine(config.db_url)
 
             initialize = e1.dialect.initialize
 
@@ -1559,10 +1557,11 @@ class EngineEventsTest(fixtures.TestBase):
             conn.exec_driver_sql(select1(testing.db))
         eq_(m1.mock_calls, [])
 
-    def test_add_event_after_connect(self):
+    def test_add_event_after_connect(self, testing_engine):
         # new feature as of #2978
+
         canary = Mock()
-        e1 = create_engine(config.db_url)
+        e1 = testing_engine(config.db_url, future=False)
         assert not e1._has_events
 
         conn = e1.connect()
@@ -1575,9 +1574,9 @@ class EngineEventsTest(fixtures.TestBase):
         conn._branch().execute(select(1))
         eq_(canary.be1.call_count, 2)
 
-    def test_force_conn_events_false(self):
+    def test_force_conn_events_false(self, testing_engine):
         canary = Mock()
-        e1 = create_engine(config.db_url)
+        e1 = testing_engine(config.db_url, future=False)
         assert not e1._has_events
 
         event.listen(e1, "before_execute", canary.be1)
@@ -1593,7 +1592,7 @@ class EngineEventsTest(fixtures.TestBase):
         conn._branch().execute(select(1))
         eq_(canary.be1.call_count, 0)
 
-    def test_cursor_events_ctx_execute_scalar(self):
+    def test_cursor_events_ctx_execute_scalar(self, testing_engine):
         canary = Mock()
         e1 = testing_engine(config.db_url)
 
@@ -1620,7 +1619,7 @@ class EngineEventsTest(fixtures.TestBase):
             [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)],
         )
 
-    def test_cursor_events_execute(self):
+    def test_cursor_events_execute(self, testing_engine):
         canary = Mock()
         e1 = testing_engine(config.db_url)
 
@@ -1653,9 +1652,15 @@ class EngineEventsTest(fixtures.TestBase):
         ),
         ((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine),
         (({"z": 10},), {}, [], {"z": 10}),
+        argnames="multiparams, params, expected_multiparams, expected_params",
     )
     def test_modify_parameters_from_event_one(
-        self, multiparams, params, expected_multiparams, expected_params
+        self,
+        multiparams,
+        params,
+        expected_multiparams,
+        expected_params,
+        testing_engine,
     ):
         # this is testing both the normalization added to parameters
         # as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as
@@ -1704,7 +1709,9 @@ class EngineEventsTest(fixtures.TestBase):
             [(15,), (19,)],
         )
 
-    def test_modify_parameters_from_event_three(self, connection):
+    def test_modify_parameters_from_event_three(
+        self, connection, testing_engine
+    ):
         def before_execute(
             conn, clauseelement, multiparams, params, execution_options
         ):
@@ -1721,7 +1728,7 @@ class EngineEventsTest(fixtures.TestBase):
             with e1.connect() as conn:
                 conn.execute(select(literal("1")))
 
-    def test_argument_format_execute(self):
+    def test_argument_format_execute(self, testing_engine):
         def before_execute(
             conn, clauseelement, multiparams, params, execution_options
         ):
@@ -1956,9 +1963,9 @@ class EngineEventsTest(fixtures.TestBase):
         )
 
     @testing.requires.ad_hoc_engines
-    def test_dispose_event(self):
+    def test_dispose_event(self, testing_engine):
         canary = Mock()
-        eng = create_engine(testing.db.url)
+        eng = testing_engine(testing.db.url)
         event.listen(eng, "engine_disposed", canary)
 
         conn = eng.connect()
@@ -2102,13 +2109,13 @@ class EngineEventsTest(fixtures.TestBase):
         event.listen(engine, "commit", tracker("commit"))
         event.listen(engine, "rollback", tracker("rollback"))
 
-        conn = engine.connect()
-        trans = conn.begin()
-        conn.execute(select(1))
-        trans.rollback()
-        trans = conn.begin()
-        conn.execute(select(1))
-        trans.commit()
+        with engine.connect() as conn:
+            trans = conn.begin()
+            conn.execute(select(1))
+            trans.rollback()
+            trans = conn.begin()
+            conn.execute(select(1))
+            trans.commit()
 
         eq_(
             canary,
@@ -2145,13 +2152,13 @@ class EngineEventsTest(fixtures.TestBase):
         event.listen(engine, "commit", tracker("commit"), named=True)
         event.listen(engine, "rollback", tracker("rollback"), named=True)
 
-        conn = engine.connect()
-        trans = conn.begin()
-        conn.execute(select(1))
-        trans.rollback()
-        trans = conn.begin()
-        conn.execute(select(1))
-        trans.commit()
+        with engine.connect() as conn:
+            trans = conn.begin()
+            conn.execute(select(1))
+            trans.rollback()
+            trans = conn.begin()
+            conn.execute(select(1))
+            trans.commit()
 
         eq_(
             canary,
@@ -2310,7 +2317,7 @@ class HandleErrorTest(fixtures.TestBase):
     __requires__ = ("ad_hoc_engines",)
     __backend__ = True
 
-    def tearDown(self):
+    def teardown_test(self):
         Engine.dispatch._clear()
         Engine._has_events = False
 
@@ -2742,7 +2749,7 @@ class HandleErrorTest(fixtures.TestBase):
 class HandleInvalidatedOnConnectTest(fixtures.TestBase):
     __requires__ = ("sqlite",)
 
-    def setUp(self):
+    def setup_test(self):
         e = create_engine("sqlite://")
 
         connection = Mock(get_server_version_info=Mock(return_value="5.0"))
@@ -3014,6 +3021,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
             ],
         )
 
+        c.close()
+        c2.close()
+
 
 class DialectEventTest(fixtures.TestBase):
     @contextmanager
@@ -3370,7 +3380,7 @@ class SetInputSizesTest(fixtures.TablesTest):
         )
 
     @testing.fixture
-    def input_sizes_fixture(self):
+    def input_sizes_fixture(self, testing_engine):
         canary = mock.Mock()
 
         def do_set_input_sizes(cursor, list_of_tuples, context):
index 29b8132aa326bbcb4a133a807e88f731f27613d4..c565892487d31109ead9e5951ba8ac3af54c5288 100644 (file)
@@ -30,7 +30,7 @@ class LogParamsTest(fixtures.TestBase):
     __only_on__ = "sqlite"
     __requires__ = ("ad_hoc_engines",)
 
-    def setup(self):
+    def setup_test(self):
         self.eng = engines.testing_engine(options={"echo": True})
         self.no_param_engine = engines.testing_engine(
             options={"echo": True, "hide_parameters": True}
@@ -44,7 +44,7 @@ class LogParamsTest(fixtures.TestBase):
         for log in [logging.getLogger("sqlalchemy.engine")]:
             log.addHandler(self.buf)
 
-    def teardown(self):
+    def teardown_test(self):
         exec_sql(self.eng, "drop table if exists foo")
         for log in [logging.getLogger("sqlalchemy.engine")]:
             log.removeHandler(self.buf)
@@ -413,14 +413,14 @@ class LogParamsTest(fixtures.TestBase):
 
 
 class PoolLoggingTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         self.existing_level = logging.getLogger("sqlalchemy.pool").level
 
         self.buf = logging.handlers.BufferingHandler(100)
         for log in [logging.getLogger("sqlalchemy.pool")]:
             log.addHandler(self.buf)
 
-    def teardown(self):
+    def teardown_test(self):
         for log in [logging.getLogger("sqlalchemy.pool")]:
             log.removeHandler(self.buf)
         logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level)
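The logging tests converted in this file all share one capture pattern: attach a BufferingHandler in setup_test, exercise the code, then inspect handler.buffer, detaching the handler in teardown_test. The pattern in stdlib-only terms (hypothetical logger name):

```python
import logging
import logging.handlers

log = logging.getLogger("example.capture")
log.setLevel(logging.DEBUG)

# BufferingHandler keeps up to N LogRecords in its .buffer list
buf = logging.handlers.BufferingHandler(100)
log.addHandler(buf)
try:
    log.info("hello %s", "world")
finally:
    log.removeHandler(buf)

messages = [rec.getMessage() for rec in buf.buffer]
assert messages == ["hello world"]
```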
@@ -528,7 +528,7 @@ class LoggingNameTest(fixtures.TestBase):
         kw.update({"echo": True})
         return engines.testing_engine(options=kw)
 
-    def setup(self):
+    def setup_test(self):
         self.buf = logging.handlers.BufferingHandler(100)
         for log in [
             logging.getLogger("sqlalchemy.engine"),
@@ -536,7 +536,7 @@ class LoggingNameTest(fixtures.TestBase):
         ]:
             log.addHandler(self.buf)
 
-    def teardown(self):
+    def teardown_test(self):
         for log in [
             logging.getLogger("sqlalchemy.engine"),
             logging.getLogger("sqlalchemy.pool"),
@@ -588,13 +588,13 @@ class LoggingNameTest(fixtures.TestBase):
 class EchoTest(fixtures.TestBase):
     __requires__ = ("ad_hoc_engines",)
 
-    def setup(self):
+    def setup_test(self):
         self.level = logging.getLogger("sqlalchemy.engine").level
         logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN)
         self.buf = logging.handlers.BufferingHandler(100)
         logging.getLogger("sqlalchemy.engine").addHandler(self.buf)
 
-    def teardown(self):
+    def teardown_test(self):
         logging.getLogger("sqlalchemy.engine").removeHandler(self.buf)
         logging.getLogger("sqlalchemy.engine").setLevel(self.level)
 
index 550fedb8e6efcf702edb3ce731b0bbb67fa6e1b3..decdce3f9bec5f1afb8b45b284259059f0a596a3 100644 (file)
@@ -17,7 +17,9 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_none
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_not_none
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.engines import testing_engine
@@ -63,18 +65,18 @@ def MockDBAPI():  # noqa
 
 
 class PoolTestBase(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         pool.clear_managers()
         self._teardown_conns = []
 
-    def teardown(self):
+    def teardown_test(self):
         for ref in self._teardown_conns:
             conn = ref()
             if conn:
                 conn.close()
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         pool.clear_managers()
 
     def _with_teardown(self, connection):
@@ -364,10 +366,17 @@ class PoolEventsTest(PoolTestBase):
         p = self._queuepool_fixture()
         canary = []
 
+        @event.listens_for(p, "checkin")
         def checkin(*arg, **kw):
             canary.append("checkin")
 
-        event.listen(p, "checkin", checkin)
+        @event.listens_for(p, "close_detached")
+        def close_detached(*arg, **kw):
+            canary.append("close_detached")
+
+        @event.listens_for(p, "detach")
+        def detach(*arg, **kw):
+            canary.append("detach")
 
         return p, canary
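The fixture above registers several listeners through the event.listens_for decorator, each appending a marker to a shared canary list. The decorator-registration pattern itself can be sketched without SQLAlchemy (hypothetical EventRegistry, not the sqlalchemy.event API):

```python
class EventRegistry:
    """Sketch: named listener lists, populated by a decorator."""

    def __init__(self):
        self._listeners = {}

    def listens_for(self, name):
        # returns a decorator that registers fn under `name`
        def decorate(fn):
            self._listeners.setdefault(name, []).append(fn)
            return fn

        return decorate

    def dispatch(self, name, *arg, **kw):
        for fn in self._listeners.get(name, ()):
            fn(*arg, **kw)


events = EventRegistry()
canary = []


@events.listens_for("checkin")
def checkin(*arg, **kw):
    canary.append("checkin")


@events.listens_for("detach")
def detach(*arg, **kw):
    canary.append("detach")
```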
 
@@ -629,15 +638,35 @@ class PoolEventsTest(PoolTestBase):
         assert canary.call_args_list[0][0][0] is dbapi_con
         assert canary.call_args_list[0][0][2] is exc
 
+    @testing.combinations((True, testing.requires.python3), (False,))
     @testing.requires.predictable_gc
-    def test_checkin_event_gc(self):
+    def test_checkin_event_gc(self, detach_gced):
         p, canary = self._checkin_event_fixture()
 
+        if detach_gced:
+            p._is_asyncio = True
+
         c1 = p.connect()
+
+        dbapi_connection = weakref.ref(c1.connection)
+
         eq_(canary, [])
         del c1
         lazy_gc()
-        eq_(canary, ["checkin"])
+
+        if detach_gced:
+            # "close_detached" is not called because for asyncio the
+            # connection is just lost.
+            eq_(canary, ["detach"])
+
+        else:
+            eq_(canary, ["checkin"])
+
+        gc_collect()
+        if detach_gced:
+            is_none(dbapi_connection())
+        else:
+            is_not_none(dbapi_connection())
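The detach_gced branches above observe DBAPI-connection lifetime with a weakref plus an explicit collection pass, the same stdlib pattern in isolation:

```python
import gc
import weakref


class FakeDBAPIConnection:
    pass


conn = FakeDBAPIConnection()
ref = weakref.ref(conn)  # track the object without keeping it alive
assert ref() is conn

del conn
gc.collect()  # analogous to the test suite's gc_collect()
assert ref() is None  # last strong reference is gone
```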
 
     def test_checkin_event_on_subsequently_recreated(self):
         p, canary = self._checkin_event_fixture()
@@ -744,7 +773,7 @@ class PoolEventsTest(PoolTestBase):
         eq_(conn.info["important_flag"], True)
         conn.close()
 
-    def teardown(self):
+    def teardown_test(self):
         # TODO: need to get remove() functionality
         # going
         pool.Pool.dispatch._clear()
@@ -1490,12 +1519,16 @@ class QueuePoolTest(PoolTestBase):
 
         self._assert_cleanup_on_pooled_reconnect(dbapi, p)
 
+    @testing.combinations((True, testing.requires.python3), (False,))
     @testing.requires.predictable_gc
-    def test_userspace_disconnectionerror_weakref_finalizer(self):
+    def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced):
         dbapi, pool = self._queuepool_dbapi_fixture(
             pool_size=1, max_overflow=2
         )
 
+        if detach_gced:
+            pool._is_asyncio = True
+
         @event.listens_for(pool, "checkout")
         def handle_checkout_event(dbapi_con, con_record, con_proxy):
             if getattr(dbapi_con, "boom") == "yes":
@@ -1514,8 +1547,12 @@ class QueuePoolTest(PoolTestBase):
         del conn
         gc_collect()
 
-        # new connection was reset on return appropriately
-        eq_(dbapi_conn.mock_calls, [call.rollback()])
+        if detach_gced:
+            # new connection was detached + abandoned on return
+            eq_(dbapi_conn.mock_calls, [])
+        else:
+            # new connection reset and returned to pool
+            eq_(dbapi_conn.mock_calls, [call.rollback()])
 
         # old connection was just closed - did not get an
         # erroneous reset on return
index 3810de06a5c8d304116893f7d4a70bc0151225d9..5a4220c827d12cb0b54eb8d3d5cf7583e5835389 100644 (file)
@@ -25,7 +25,7 @@ class CBooleanProcessorTest(_BooleanProcessorTest):
     __requires__ = ("cextensions",)
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy import cprocessors
 
         cls.module = cprocessors
@@ -83,7 +83,7 @@ class _DateProcessorTest(fixtures.TestBase):
 
 class PyDateProcessorTest(_DateProcessorTest):
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy import processors
 
         cls.module = type(
@@ -100,7 +100,7 @@ class CDateProcessorTest(_DateProcessorTest):
     __requires__ = ("cextensions",)
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy import cprocessors
 
         cls.module = cprocessors
@@ -185,7 +185,7 @@ class _DistillArgsTest(fixtures.TestBase):
 
 class PyDistillArgsTest(_DistillArgsTest):
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy.engine import util
 
         cls.module = type(
@@ -202,7 +202,7 @@ class CDistillArgsTest(_DistillArgsTest):
     __requires__ = ("cextensions",)
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy import cutils as util
 
         cls.module = util
index 5fe7f6cc2a72f143bb991b315cbd7bb96ad542c9..7a64b25508b6e72c5c85147a0894aad510dba498 100644 (file)
@@ -162,7 +162,7 @@ def MockDBAPI():
 
 
 class PrePingMockTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         self.dbapi = MockDBAPI()
 
     def _pool_fixture(self, pre_ping, pool_kw=None):
@@ -182,7 +182,7 @@ class PrePingMockTest(fixtures.TestBase):
         )
         return _pool
 
-    def teardown(self):
+    def teardown_test(self):
         self.dbapi.dispose()
 
     def test_ping_not_on_first_connect(self):
@@ -357,7 +357,7 @@ class PrePingMockTest(fixtures.TestBase):
 
 
 class MockReconnectTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         self.dbapi = MockDBAPI()
 
         self.db = testing_engine(
@@ -373,7 +373,7 @@ class MockReconnectTest(fixtures.TestBase):
             e, MockDisconnect
         )
 
-    def teardown(self):
+    def teardown_test(self):
         self.dbapi.dispose()
 
     def test_reconnect(self):
@@ -1004,10 +1004,10 @@ class RealReconnectTest(fixtures.TestBase):
     __backend__ = True
     __requires__ = "graceful_disconnects", "ad_hoc_engines"
 
-    def setup(self):
+    def setup_test(self):
         self.engine = engines.reconnecting_engine()
 
-    def teardown(self):
+    def teardown_test(self):
         self.engine.dispose()
 
     def test_reconnect(self):
@@ -1336,7 +1336,7 @@ class PrePingRealTest(fixtures.TestBase):
 class InvalidateDuringResultTest(fixtures.TestBase):
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         self.engine = engines.reconnecting_engine()
         self.meta = MetaData()
         table = Table(
@@ -1353,7 +1353,7 @@ class InvalidateDuringResultTest(fixtures.TestBase):
                 [{"id": i, "name": "row %d" % i} for i in range(1, 100)],
             )
 
-    def teardown(self):
+    def teardown_test(self):
         with self.engine.begin() as conn:
             self.meta.drop_all(conn)
         self.engine.dispose()
@@ -1470,7 +1470,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
 
     __backend__ = True
 
-    def setup(self):
+    def setup_test(self):
         self.engine = engines.reconnecting_engine(
             options=dict(future=self.future)
         )
@@ -1483,7 +1483,7 @@ class ReconnectRecipeTest(fixtures.TestBase):
         )
         self.meta.create_all(self.engine)
 
-    def teardown(self):
+    def teardown_test(self):
         self.meta.drop_all(self.engine)
         self.engine.dispose()
 
index 658cdd79f02276d106b6c00a8c08c660f96d171c..0a46ddeecaa86c759fe5f073c1a4ed4ec2983d9d 100644 (file)
@@ -796,7 +796,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables):
         assert f1 in b1.constraints
         assert len(b1.constraints) == 2
 
-    def test_override_keys(self, connection, metadata):
+    def test_override_keys(self, metadata, connection):
         """test that columns can be overridden with a 'key',
         and that ForeignKey targeting during reflection still works."""
 
@@ -1375,7 +1375,7 @@ class CreateDropTest(fixtures.TablesTest):
     run_create_tables = None
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         # TablesTest is used here without
         # run_create_tables, so add an explicit drop of whatever is in
         # metadata
@@ -1658,7 +1658,6 @@ class SchemaTest(fixtures.TestBase):
     @testing.requires.schemas
     @testing.requires.cross_schema_fk_reflection
     @testing.requires.implicit_default_schema
-    @testing.provide_metadata
     def test_blank_schema_arg(self, connection, metadata):
 
         Table(
@@ -1913,7 +1912,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
     __backend__ = True
 
     @testing.requires.denormalized_names
-    def setup(self):
+    def setup_test(self):
         with testing.db.begin() as conn:
             conn.exec_driver_sql(
                 """
@@ -1926,7 +1925,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
             )
 
     @testing.requires.denormalized_names
-    def teardown(self):
+    def teardown_test(self):
         with testing.db.begin() as conn:
             conn.exec_driver_sql("drop table weird_casing")
 
index 79126fc5bb9ddc4c00192250b534f358999a5752..47504b60a32a11633f852800994b400f77366801 100644 (file)
@@ -1,6 +1,5 @@
 import sys
 
-from sqlalchemy import create_engine
 from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import func
@@ -640,12 +639,12 @@ class AutoRollbackTest(fixtures.TestBase):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global metadata
         metadata = MetaData()
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         metadata.drop_all(testing.db)
 
     def test_rollback_deadlock(self):
@@ -871,11 +870,13 @@ class IsolationLevelTest(fixtures.TestBase):
 
     def test_per_engine(self):
         # new in 0.9
-        eng = create_engine(
+        eng = testing_engine(
             testing.db.url,
-            execution_options={
-                "isolation_level": self._non_default_isolation_level()
-            },
+            options=dict(
+                execution_options={
+                    "isolation_level": self._non_default_isolation_level()
+                }
+            ),
         )
         conn = eng.connect()
         eq_(
@@ -884,7 +885,7 @@ class IsolationLevelTest(fixtures.TestBase):
         )
 
     def test_per_option_engine(self):
-        eng = create_engine(testing.db.url).execution_options(
+        eng = testing_engine(testing.db.url).execution_options(
             isolation_level=self._non_default_isolation_level()
         )
 
@@ -895,14 +896,14 @@ class IsolationLevelTest(fixtures.TestBase):
         )
 
     def test_isolation_level_accessors_connection_default(self):
-        eng = create_engine(testing.db.url)
+        eng = testing_engine(testing.db.url)
         with eng.connect() as conn:
             eq_(conn.default_isolation_level, self._default_isolation_level())
         with eng.connect() as conn:
             eq_(conn.get_isolation_level(), self._default_isolation_level())
 
     def test_isolation_level_accessors_connection_option_modified(self):
-        eng = create_engine(testing.db.url)
+        eng = testing_engine(testing.db.url)
         with eng.connect() as conn:
             c2 = conn.execution_options(
                 isolation_level=self._non_default_isolation_level()
index 7dae1411e542ae1884ac76cc12ac85c8bd277850..59a44f8e2e1ebbab9ecc4f11d85410c7780de170 100644 (file)
@@ -269,7 +269,7 @@ class AsyncEngineTest(EngineFixture):
                 await trans.rollback(),
 
     @async_test
-    async def test_pool_exhausted(self, async_engine):
+    async def test_pool_exhausted_some_timeout(self, async_engine):
         engine = create_async_engine(
             testing.db.url,
             pool_size=1,
@@ -277,7 +277,19 @@ class AsyncEngineTest(EngineFixture):
             pool_timeout=0.1,
         )
         async with engine.connect():
-            with expect_raises(asyncio.TimeoutError):
+            with expect_raises(exc.TimeoutError):
+                await engine.connect()
+
+    @async_test
+    async def test_pool_exhausted_no_timeout(self, async_engine):
+        engine = create_async_engine(
+            testing.db.url,
+            pool_size=1,
+            max_overflow=0,
+            pool_timeout=0,
+        )
+        async with engine.connect():
+            with expect_raises(exc.TimeoutError):
                 await engine.connect()
 
     @async_test
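The split into `test_pool_exhausted_some_timeout` and `test_pool_exhausted_no_timeout` exercises the pool contract that with `pool_size=1` and `max_overflow=0` a second checkout waits at most `pool_timeout` seconds and then raises, where a timeout of zero means "fail immediately". A toy semaphore-backed pool showing that contract (here plain `TimeoutError` plays the role of the library's `exc.TimeoutError`):

```python
import threading


class TinyPool:
    """Toy bounded pool: pool_size slots, no overflow, timeout on checkout."""

    def __init__(self, pool_size=1, pool_timeout=0.1):
        self._sem = threading.Semaphore(pool_size)
        self._timeout = pool_timeout

    def connect(self):
        # timeout=0 means "fail immediately if no slot is available"
        if not self._sem.acquire(timeout=self._timeout):
            raise TimeoutError("pool exhausted")
        return self

    def close(self):
        # return the slot to the pool
        self._sem.release()
```

Both diff-added tests follow the same shape: hold the only connection, then assert the second `connect()` raises the timeout error.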
index 2b80b753ebe6115884b96fe6c70d16367e09e3f0..e25e7cfc292faef2169ee63a506bb4d3bdd2309e 100644 (file)
@@ -27,11 +27,11 @@ Base = None
 
 
 class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
-    def setup(self):
+    def setup_test(self):
         global Base
         Base = decl.declarative_base(testing.db)
 
-    def teardown(self):
+    def teardown_test(self):
         close_all_sessions()
         clear_mappers()
         Base.metadata.drop_all(testing.db)
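The renames above (`setup`/`teardown` to `setup_test`/`teardown_test`, and the class-level variants) are the commit's core mechanism: pytest hardcodes the old names and runs them outside function-scoped fixtures, so the suite switches to its own hook names and drives them itself (in the real harness, via an autouse fixture). A stand-in runner showing the intended ordering; the hook names match the commit, the runner body is illustrative:

```python
def run_test(case, method_name):
    """Drive the renamed lifecycle hooks around a single test method,
    playing the role of the harness's autouse fixture."""
    cls = type(case)
    if hasattr(cls, "setup_test_class"):
        cls.setup_test_class()
    if hasattr(case, "setup_test"):
        case.setup_test()
    getattr(case, method_name)()
    if hasattr(case, "teardown_test"):
        case.teardown_test()
    if hasattr(cls, "teardown_test_class"):
        cls.teardown_test_class()


class ExampleTest:
    """Records the order in which its hooks fire."""

    log = []

    @classmethod
    def setup_test_class(cls):
        cls.log.append("setup_test_class")

    def setup_test(self):
        self.log.append("setup_test")

    def test_something(self):
        self.log.append("test")

    def teardown_test(self):
        self.log.append("teardown_test")

    @classmethod
    def teardown_test_class(cls):
        cls.log.append("teardown_test_class")
```

Because the harness, not pytest, calls these hooks, it can run function-scoped fixtures such as `connection` inside them.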
index d7fcbf9e8e578135321bc88bc4819630043563e8..c327de7d4ff5804c570f913b541d9f7cb09af862 100644 (file)
@@ -4,7 +4,6 @@ from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.ext.declarative import DeferredReflection
 from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
 from sqlalchemy.orm import decl_api as decl
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import exc as orm_exc
@@ -14,6 +13,7 @@ from sqlalchemy.orm.decl_base import _DeferredMapperConfig
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import gc_collect
@@ -22,20 +22,19 @@ from sqlalchemy.testing.util import gc_collect
 class DeclarativeReflectionBase(fixtures.TablesTest):
     __requires__ = ("reflectable_autoincrement",)
 
-    def setup(self):
+    def setup_test(self):
         global Base, registry
 
         registry = decl.registry()
         Base = registry.generate_base()
 
-    def teardown(self):
-        super(DeclarativeReflectionBase, self).teardown()
+    def teardown_test(self):
         clear_mappers()
 
 
 class DeferredReflectBase(DeclarativeReflectionBase):
-    def teardown(self):
-        super(DeferredReflectBase, self).teardown()
+    def teardown_test(self):
+        super(DeferredReflectBase, self).teardown_test()
         _DeferredMapperConfig._configs.clear()
 
 
@@ -101,22 +100,23 @@ class DeferredReflectionTest(DeferredReflectBase):
         u1 = User(
             name="u1", addresses=[Address(email="one"), Address(email="two")]
         )
-        sess = create_session(testing.db)
-        sess.add(u1)
-        sess.flush()
-        sess.expunge_all()
-        eq_(
-            sess.query(User).all(),
-            [
-                User(
-                    name="u1",
-                    addresses=[Address(email="one"), Address(email="two")],
-                )
-            ],
-        )
-        a1 = sess.query(Address).filter(Address.email == "two").one()
-        eq_(a1, Address(email="two"))
-        eq_(a1.user, User(name="u1"))
+        with fixture_session() as sess:
+            sess.add(u1)
+            sess.commit()
+
+        with fixture_session() as sess:
+            eq_(
+                sess.query(User).all(),
+                [
+                    User(
+                        name="u1",
+                        addresses=[Address(email="one"), Address(email="two")],
+                    )
+                ],
+            )
+            a1 = sess.query(Address).filter(Address.email == "two").one()
+            eq_(a1, Address(email="two"))
+            eq_(a1.user, User(name="u1"))
 
     def test_exception_prepare_not_called(self):
         class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -191,15 +191,25 @@ class DeferredReflectionTest(DeferredReflectBase):
                 return {"primary_key": cls.__table__.c.id}
 
         DeferredReflection.prepare(testing.db)
-        sess = Session(testing.db)
-        sess.add_all(
-            [User(name="G"), User(name="Q"), User(name="A"), User(name="C")]
-        )
-        sess.commit()
-        eq_(
-            sess.query(User).order_by(User.name).all(),
-            [User(name="A"), User(name="C"), User(name="G"), User(name="Q")],
-        )
+        with fixture_session() as sess:
+            sess.add_all(
+                [
+                    User(name="G"),
+                    User(name="Q"),
+                    User(name="A"),
+                    User(name="C"),
+                ]
+            )
+            sess.commit()
+            eq_(
+                sess.query(User).order_by(User.name).all(),
+                [
+                    User(name="A"),
+                    User(name="C"),
+                    User(name="G"),
+                    User(name="Q"),
+                ],
+            )
 
     @testing.requires.predictable_gc
     def test_cls_not_strong_ref(self):
@@ -255,14 +265,14 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase):
 
         u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")])
 
-        sess = Session(testing.db)
-        sess.add(u1)
-        sess.commit()
+        with fixture_session() as sess:
+            sess.add(u1)
+            sess.commit()
 
-        eq_(
-            sess.query(User).all(),
-            [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
-        )
+            eq_(
+                sess.query(User).all(),
+                [User(name="u1", items=[Item(name="i1"), Item(name="i2")])],
+            )
 
     def test_string_resolution(self):
         class User(DeferredReflection, fixtures.ComparableEntity, Base):
@@ -296,27 +306,26 @@ class DeferredInhReflectBase(DeferredReflectBase):
         Foo = Base.registry._class_registry["Foo"]
         Bar = Base.registry._class_registry["Bar"]
 
-        s = Session(testing.db)
-
-        s.add_all(
-            [
-                Bar(data="d1", bar_data="b1"),
-                Bar(data="d2", bar_data="b2"),
-                Bar(data="d3", bar_data="b3"),
-                Foo(data="d4"),
-            ]
-        )
-        s.commit()
-
-        eq_(
-            s.query(Foo).order_by(Foo.id).all(),
-            [
-                Bar(data="d1", bar_data="b1"),
-                Bar(data="d2", bar_data="b2"),
-                Bar(data="d3", bar_data="b3"),
-                Foo(data="d4"),
-            ],
-        )
+        with fixture_session() as s:
+            s.add_all(
+                [
+                    Bar(data="d1", bar_data="b1"),
+                    Bar(data="d2", bar_data="b2"),
+                    Bar(data="d3", bar_data="b3"),
+                    Foo(data="d4"),
+                ]
+            )
+            s.commit()
+
+            eq_(
+                s.query(Foo).order_by(Foo.id).all(),
+                [
+                    Bar(data="d1", bar_data="b1"),
+                    Bar(data="d2", bar_data="b2"),
+                    Bar(data="d3", bar_data="b3"),
+                    Foo(data="d4"),
+                ],
+            )
 
 
 class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
index b1f5cc956f00c307af276fb5f4911b2662c72a73..31ae050c1150fa6ca3108a719606173ae2410952 100644 (file)
@@ -101,7 +101,7 @@ class AutoFlushTest(fixtures.TablesTest):
             Column("name", String(50)),
         )
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     def _fixture(self, collection_class, is_dict=False):
@@ -198,7 +198,7 @@ class AutoFlushTest(fixtures.TablesTest):
 
 
 class _CollectionOperations(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         collection_class = self.collection_class
 
         metadata = MetaData()
@@ -260,7 +260,7 @@ class _CollectionOperations(fixtures.TestBase):
         self.session = fixture_session()
         self.Parent, self.Child = Parent, Child
 
-    def teardown(self):
+    def teardown_test(self):
         self.metadata.drop_all(testing.db)
 
     def roundtrip(self, obj):
@@ -885,7 +885,7 @@ class CustomObjectTest(_CollectionOperations):
 
 
 class ProxyFactoryTest(ListTest):
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
 
         parents_table = Table(
@@ -1157,7 +1157,7 @@ class ScalarTest(fixtures.TestBase):
 
 
 class LazyLoadTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
 
         parents_table = Table(
@@ -1197,7 +1197,7 @@ class LazyLoadTest(fixtures.TestBase):
         self.Parent, self.Child = Parent, Child
         self.table = parents_table
 
-    def teardown(self):
+    def teardown_test(self):
         self.metadata.drop_all(testing.db)
 
     def roundtrip(self, obj):
@@ -2294,7 +2294,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL):
 
 
 class DictOfTupleUpdateTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         class B(object):
             def __init__(self, key, elem):
                 self.key = key
@@ -2434,7 +2434,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest):
 
 
 class AttributeAccessTest(fixtures.TestBase):
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     def test_resolve_aliased_class(self):
index 71fabc629f5ddea6bf4bbb8bba8c2edb2a8b2ba5..2d4e9848e571e57786a9a886868c71ca47112bdf 100644 (file)
@@ -27,7 +27,7 @@ class BakedTest(_fixtures.FixtureTest):
     run_inserts = "once"
     run_deletes = None
 
-    def setup(self):
+    def setup_test(self):
         self.bakery = baked.bakery()
 
 
index 058c1dfd77edffbfe45ab2d398151b1504aa95f4..d011417d77f4bb3b957d18e07f478bcf2d25fc9e 100644 (file)
@@ -426,7 +426,7 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
 
     __dialect__ = "default"
 
-    def teardown(self):
+    def teardown_test(self):
         for cls in (Select, BindParameter):
             deregister(cls)
 
index ad9bf0bc05005085b4a7f0f59572befd68f4656d..f3eceb0dca5f2ef6d8512fba2480eeb0989103c8 100644 (file)
@@ -32,7 +32,7 @@ def modifies_instrumentation_finders(fn, *args, **kw):
 
 class _ExtBase(object):
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         instrumentation._reinstall_default_lookups()
 
 
@@ -89,7 +89,7 @@ MyBaseClass, MyClass = None, None
 
 class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global MyBaseClass, MyClass
 
         class MyBaseClass(object):
@@ -143,7 +143,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
                 else:
                     del self._goofy_dict[key]
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     def test_instance_dict(self):
index 038bdd83e1b499b6dee45d98bded965612a57cd6..bb06d9648ba905ca7badaea2e47b9d7c445bc3a7 100644 (file)
@@ -19,7 +19,6 @@ from sqlalchemy import update
 from sqlalchemy import util
 from sqlalchemy.ext.horizontal_shard import ShardedSession
 from sqlalchemy.orm import clear_mappers
-from sqlalchemy.orm import create_session
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
@@ -42,7 +41,7 @@ class ShardTest(object):
 
     schema = None
 
-    def setUp(self):
+    def setup_test(self):
         global db1, db2, db3, db4, weather_locations, weather_reports
 
         db1, db2, db3, db4 = self._dbs = self._init_dbs()
@@ -88,7 +87,7 @@ class ShardTest(object):
 
     @classmethod
     def setup_session(cls):
-        global create_session
+        global sharded_session
         shard_lookup = {
             "North America": "north_america",
             "Asia": "asia",
@@ -128,10 +127,10 @@ class ShardTest(object):
             else:
                 return ids
 
-        create_session = sessionmaker(
+        sharded_session = sessionmaker(
             class_=ShardedSession, autoflush=True, autocommit=False
         )
-        create_session.configure(
+        sharded_session.configure(
             shards={
                 "north_america": db1,
                 "asia": db2,
@@ -180,7 +179,7 @@ class ShardTest(object):
         tokyo.reports.append(Report(80.0, id_=1))
         newyork.reports.append(Report(75, id_=1))
         quito.reports.append(Report(85))
-        sess = create_session(future=True)
+        sess = sharded_session(future=True)
         for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
             sess.add(c)
         sess.flush()
@@ -671,11 +670,10 @@ class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
         self.dbs = [db1, db2, db3, db4]
         return self.dbs
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
-        for db in self.dbs:
-            db.connect().invalidate()
+        testing_reaper.checkin_all()
         for i in range(1, 5):
             os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
 
@@ -702,10 +700,10 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase):
         self.engine = e
         return db1, db2, db3, db4
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
-        self.engine.connect().invalidate()
+        testing_reaper.checkin_all()
         for i in range(1, 5):
             os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
 
@@ -778,10 +776,13 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
         self.postgresql_engine = e2
         return db1, db2, db3, db4
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
-        self.sqlite_engine.connect().invalidate()
+        # the tests in this suite don't cleanly close out the Session
+        # at the moment, so use the reaper to close all connections
+        testing_reaper.checkin_all()
+
         for i in [1, 3]:
             os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
 
@@ -789,6 +790,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
             self.tables_test_metadata.drop_all(conn)
             for i in [2, 4]:
                 conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
+        self.postgresql_engine.dispose()
 
 
 class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
@@ -904,11 +906,11 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
 
         return self.dbs
 
-    def teardown(self):
+    def teardown_test(self):
         for db in self.dbs:
             db.connect().invalidate()
 
-        testing_reaper.close_all()
+        testing_reaper.checkin_all()
         for i in range(1, 3):
             os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
 
index 048a8b52d16a715831bbdb2bcc671f82756805ee..3bab7db934b4283fd8ffea104c7bd6df4e8ba379 100644 (file)
@@ -697,7 +697,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy import literal
 
         symbols = ("usd", "gbp", "cad", "eur", "aud")
index eba2ac0cbb5343ee745440c3dbaa2257c923126a..21244de73d6891c6c34204d64e4692dcea3a3586 100644 (file)
@@ -90,11 +90,10 @@ class _MutableDictTestFixture(object):
     def _type_fixture(cls):
         return MutableDict
 
-    def teardown(self):
+    def teardown_test(self):
         # clear out mapper events
         Mapper.dispatch._clear()
         ClassManager.dispatch._clear()
-        super(_MutableDictTestFixture, self).teardown()
 
 
 class _MutableDictTestBase(_MutableDictTestFixture):
@@ -312,11 +311,10 @@ class _MutableListTestFixture(object):
     def _type_fixture(cls):
         return MutableList
 
-    def teardown(self):
+    def teardown_test(self):
         # clear out mapper events
         Mapper.dispatch._clear()
         ClassManager.dispatch._clear()
-        super(_MutableListTestFixture, self).teardown()
 
 
 class _MutableListTestBase(_MutableListTestFixture):
@@ -619,11 +617,10 @@ class _MutableSetTestFixture(object):
     def _type_fixture(cls):
         return MutableSet
 
-    def teardown(self):
+    def teardown_test(self):
         # clear out mapper events
         Mapper.dispatch._clear()
         ClassManager.dispatch._clear()
-        super(_MutableSetTestFixture, self).teardown()
 
 
 class _MutableSetTestBase(_MutableSetTestFixture):
@@ -1234,17 +1231,15 @@ class _CompositeTestBase(object):
             Column("unrelated_data", String(50)),
         )
 
-    def setup(self):
+    def setup_test(self):
         from sqlalchemy.ext import mutable
 
         mutable._setup_composite_listener()
-        super(_CompositeTestBase, self).setup()
 
-    def teardown(self):
+    def teardown_test(self):
         # clear out mapper events
         Mapper.dispatch._clear()
         ClassManager.dispatch._clear()
-        super(_CompositeTestBase, self).teardown()
 
     @classmethod
     def _type_fixture(cls):
index f23d6cb576307e117a72c3c94c16e63ee894d4bc..280fad6cf09446798188b9ef7aab2c7f687d1ed9 100644 (file)
@@ -8,7 +8,7 @@ from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.fixtures import create_session
+from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import picklers
@@ -60,7 +60,7 @@ def alpha_ordering(index, collection):
 
 
 class OrderingListTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         global metadata, slides_table, bullets_table, Slide, Bullet
         slides_table, bullets_table = None, None
         Slide, Bullet = None, None
@@ -122,7 +122,7 @@ class OrderingListTest(fixtures.TestBase):
 
         metadata.create_all(testing.db)
 
-    def teardown(self):
+    def teardown_test(self):
         metadata.drop_all(testing.db)
 
     def test_append_no_reorder(self):
@@ -167,7 +167,7 @@ class OrderingListTest(fixtures.TestBase):
         self.assert_(s1.bullets[2].position == 3)
         self.assert_(s1.bullets[3].position == 4)
 
-        session = create_session()
+        session = fixture_session()
         session.add(s1)
         session.flush()
 
@@ -232,7 +232,7 @@ class OrderingListTest(fixtures.TestBase):
 
         s1.bullets._reorder()
         self.assert_(s1.bullets[4].position == 5)
-        session = create_session()
+        session = fixture_session()
         session.add(s1)
         session.flush()
 
@@ -289,7 +289,7 @@ class OrderingListTest(fixtures.TestBase):
         self.assert_(len(s1.bullets) == 6)
         self.assert_(s1.bullets[5].position == 5)
 
-        session = create_session()
+        session = fixture_session()
         session.add(s1)
         session.flush()
 
@@ -338,7 +338,7 @@ class OrderingListTest(fixtures.TestBase):
             self.assert_(s1.bullets[li].position == li)
             self.assert_(s1.bullets[li] == b[bi])
 
-        session = create_session()
+        session = fixture_session()
         session.add(s1)
         session.flush()
 
@@ -365,7 +365,7 @@ class OrderingListTest(fixtures.TestBase):
         self.assert_(len(s1.bullets) == 3)
         self.assert_(s1.bullets[2].position == 2)
 
-        session = create_session()
+        session = fixture_session()
         session.add(s1)
         session.flush()
 
index 4c005d336c964bd18eb12d7fda3f2a8eb007f773..4d9162105b81ecc821aae8166848b91ac698e64d 100644 (file)
@@ -61,11 +61,11 @@ class DeclarativeTestBase(
 ):
     __dialect__ = "default"
 
-    def setup(self):
+    def setup_test(self):
         global Base
         Base = declarative_base(testing.db)
 
-    def teardown(self):
+    def teardown_test(self):
         close_all_sessions()
         clear_mappers()
         Base.metadata.drop_all(testing.db)
index 5f12d82723d771272c15eb79bd702263aafc52ba..ecddc2e5fa2a5ed3b19422533ad87260227c93b9 100644 (file)
@@ -17,7 +17,7 @@ from sqlalchemy.testing.fixtures import fixture_session
 
 
 class ConcurrentUseDeclMappingTest(fixtures.TestBase):
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     @classmethod
index cc29cab7deebcbb15f5e462f8d50df9057e24a5e..e09b1570e27743fcf3b7694fb85ad42694dc321b 100644 (file)
@@ -27,11 +27,11 @@ Base = None
 
 
 class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
-    def setup(self):
+    def setup_test(self):
         global Base
         Base = decl.declarative_base(testing.db)
 
-    def teardown(self):
+    def teardown_test(self):
         close_all_sessions()
         clear_mappers()
         Base.metadata.drop_all(testing.db)
index 631527daf92c7034d1fca2c38eef863756d2be41..ad4832c35737fe4d4ae68387843f4fffc2e75613 100644 (file)
@@ -38,13 +38,13 @@ mapper_registry = None
 
 
 class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
-    def setup(self):
+    def setup_test(self):
         global Base, mapper_registry
 
         mapper_registry = registry(metadata=MetaData())
         Base = mapper_registry.generate_base()
 
-    def teardown(self):
+    def teardown_test(self):
         close_all_sessions()
         clear_mappers()
         with testing.db.begin() as conn:
index 241528c44ed241c55c48b7af72bfbc3cc3fd7c5d..e7b2a70588b44ea974f28e45512d6c2c3591368d 100644 (file)
@@ -17,14 +17,13 @@ from sqlalchemy.testing.schema import Table
 class DeclarativeReflectionBase(fixtures.TablesTest):
     __requires__ = ("reflectable_autoincrement",)
 
-    def setup(self):
+    def setup_test(self):
         global Base, registry
 
         registry = decl.registry(metadata=MetaData())
         Base = registry.generate_base()
 
-    def teardown(self):
-        super(DeclarativeReflectionBase, self).teardown()
+    def teardown_test(self):
         clear_mappers()
 
 
index bdcdedc44e4383713fbe1e42a55df5d70906fcca..da07b4941b23daa906d221ffee2fb2dc793b5ffe 100644 (file)
@@ -31,7 +31,6 @@ from sqlalchemy.orm import synonym
 from sqlalchemy.orm.util import instance_str
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
-from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -1889,7 +1888,6 @@ class VersioningTest(fixtures.MappedTest):
 
     @testing.emits_warning(r".*updated rowcount")
     @testing.requires.sane_rowcount_w_returning
-    @engines.close_open_connections
     def test_save_update(self):
         subtable, base, stuff = (
             self.tables.subtable,
@@ -2927,7 +2925,7 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase):
         )
         return parent, child
 
-    def tearDown(self):
+    def teardown_test(self):
         clear_mappers()
 
     def test_warning_on_sub(self):
@@ -3417,27 +3415,26 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
     @classmethod
     def insert_data(cls, connection):
         Parent, A, B = cls.classes("Parent", "A", "B")
-        s = fixture_session()
-
-        p1 = Parent(id=1)
-        p2 = Parent(id=2)
-        s.add_all([p1, p2])
-        s.flush()
+        with Session(connection) as s:
+            p1 = Parent(id=1)
+            p2 = Parent(id=2)
+            s.add_all([p1, p2])
+            s.flush()
 
-        s.add_all(
-            [
-                A(id=1, parent_id=1),
-                B(id=2, parent_id=1),
-                A(id=3, parent_id=1),
-                B(id=4, parent_id=1),
-            ]
-        )
-        s.flush()
+            s.add_all(
+                [
+                    A(id=1, parent_id=1),
+                    B(id=2, parent_id=1),
+                    A(id=3, parent_id=1),
+                    B(id=4, parent_id=1),
+                ]
+            )
+            s.flush()
 
-        s.query(A).filter(A.id.in_([3, 4])).update(
-            {A.type: None}, synchronize_session=False
-        )
-        s.commit()
+            s.query(A).filter(A.id.in_([3, 4])).update(
+                {A.type: None}, synchronize_session=False
+            )
+            s.commit()
 
     def test_pk_is_null(self):
         Parent, A = self.classes("Parent", "A")
@@ -3527,10 +3524,12 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest):
         ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes(
             "ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB"
         )
-        s = fixture_session()
+        with Session(connection) as s:
 
-        s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()])
-        s.commit()
+            s.add_all(
+                [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]
+            )
+            s.commit()
 
     def test_single_invalid_ident(self):
         ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA")
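The `insert_data()` hunks above now build their session directly on the connection passed in (`with Session(connection) as s:`), so fixture rows ride the transaction the test harness controls instead of opening a separate one via `fixture_session()`. A stub version of that binding (both classes here are illustrative):

```python
class ConnectionStub:
    """Stand-in for the harness-controlled connection."""

    def __init__(self):
        self.statements = []


class BoundSession:
    """Toy session bound to an existing connection; work it does goes
    through that connection's transaction."""

    def __init__(self, connection):
        self.bind = connection
        self.closed = False

    def add_all(self, objs):
        self.bind.statements.extend(objs)

    def flush(self):
        pass

    def commit(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.closed = True
```

Because the session writes through the harness's connection, the harness can roll fixture data back without the test leaking its own transaction.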
index 8820aa6a459b0ab1b0c9d4b909c4be0157ed06fa..0a0a5d12b75ffb5de943fa82b8395d87df42e0bd 100644 (file)
@@ -209,7 +209,7 @@ class AttributeImplAPITest(fixtures.MappedTest):
 
 
 class AttributesTest(fixtures.ORMTest):
-    def setup(self):
+    def setup_test(self):
         global MyTest, MyTest2
 
         class MyTest(object):
@@ -218,7 +218,7 @@ class AttributesTest(fixtures.ORMTest):
         class MyTest2(object):
             pass
 
-    def teardown(self):
+    def teardown_test(self):
         global MyTest, MyTest2
         MyTest, MyTest2 = None, None
 
@@ -3690,7 +3690,7 @@ class EventPropagateTest(fixtures.TestBase):
 
 
 class CollectionInitTest(fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         class A(object):
             pass
 
@@ -3749,7 +3749,7 @@ class CollectionInitTest(fixtures.TestBase):
 
 
 class TestUnlink(fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         class A(object):
             pass
 
index 2f54f7fff0ed771c1dc832f88107b1a3e7bb8e21..014fa152e9071dbe47ad72cd6ff3518abdae88df 100644 (file)
@@ -428,39 +428,46 @@ class BindIntegrationTest(_fixtures.FixtureTest):
         User, users = self.classes.User, self.tables.users
 
         mapper(User, users)
-        c = testing.db.connect()
-
-        sess = Session(bind=c, autocommit=False)
-        u = User(name="u1")
-        sess.add(u)
-        sess.flush()
-        sess.close()
-        assert not c.in_transaction()
-        assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
-        sess = Session(bind=c, autocommit=False)
-        u = User(name="u2")
-        sess.add(u)
-        sess.flush()
-        sess.commit()
-        assert not c.in_transaction()
-        assert c.exec_driver_sql("select count(1) from users").scalar() == 1
-
-        with c.begin():
-            c.exec_driver_sql("delete from users")
-        assert c.exec_driver_sql("select count(1) from users").scalar() == 0
-
-        c = testing.db.connect()
-
-        trans = c.begin()
-        sess = Session(bind=c, autocommit=True)
-        u = User(name="u3")
-        sess.add(u)
-        sess.flush()
-        assert c.in_transaction()
-        trans.commit()
-        assert not c.in_transaction()
-        assert c.exec_driver_sql("select count(1) from users").scalar() == 1
+        with testing.db.connect() as c:
+
+            sess = Session(bind=c, autocommit=False)
+            u = User(name="u1")
+            sess.add(u)
+            sess.flush()
+            sess.close()
+            assert not c.in_transaction()
+            assert (
+                c.exec_driver_sql("select count(1) from users").scalar() == 0
+            )
+
+            sess = Session(bind=c, autocommit=False)
+            u = User(name="u2")
+            sess.add(u)
+            sess.flush()
+            sess.commit()
+            assert not c.in_transaction()
+            assert (
+                c.exec_driver_sql("select count(1) from users").scalar() == 1
+            )
+
+            with c.begin():
+                c.exec_driver_sql("delete from users")
+            assert (
+                c.exec_driver_sql("select count(1) from users").scalar() == 0
+            )
+
+        with testing.db.connect() as c:
+            trans = c.begin()
+            sess = Session(bind=c, autocommit=True)
+            u = User(name="u3")
+            sess.add(u)
+            sess.flush()
+            assert c.in_transaction()
+            trans.commit()
+            assert not c.in_transaction()
+            assert (
+                c.exec_driver_sql("select count(1) from users").scalar() == 1
+            )
 
 
 class SessionBindTest(fixtures.MappedTest):
@@ -506,6 +513,7 @@ class SessionBindTest(fixtures.MappedTest):
             finally:
                 if hasattr(bind, "close"):
                     bind.close()
+                sess.close()
 
     def test_session_unbound(self):
         Foo = self.classes.Foo
index 3d09bd44605dfbae902aea7c516d7d913a1fd05d..2a0aafbbcc63ca9106ce53052a1a716c45ea0cb3 100644 (file)
@@ -92,13 +92,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest):
             return str((id(self), self.a, self.b, self.c))
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         instrumentation.register_class(cls.Entity)
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         instrumentation.unregister_class(cls.Entity)
-        super(CollectionsTest, cls).teardown_class()
 
     _entity_id = 1
 
index df652daf45e7fa0becff7c9869de23da74778585..20d8ecc2db1ecf5ff163f25036b07861575cb0de 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.testing import fixtures
 class CompileTest(fixtures.ORMTest):
     """test various mapper compilation scenarios"""
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     def test_with_polymorphic(self):
index e1ef67fed61b3c64a8409ece22eef15e41e96493..ed11b89c9b9bc3a63469d5f5f0932776ba982565 100644 (file)
@@ -1743,8 +1743,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest):
             id = Column(Integer, primary_key=True)
             a_id = Column(ForeignKey("a.id", name="a_fk"))
 
-    def setup(self):
-        super(PostUpdateOnUpdateTest, self).setup()
+    def setup_test(self):
         PostUpdateOnUpdateTest.counter = count()
         PostUpdateOnUpdateTest.db_counter = count()
 
index 6d946cfe6e57cbcb6ef096e063626eaa48c01ca7..15063ebe929fd46e6e3d348ab09e525a32ba65c0 100644 (file)
@@ -2199,6 +2199,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture):
 
 class AutocommitClosesOnFailTest(fixtures.MappedTest):
     __requires__ = ("deferrable_fks",)
+    __only_on__ = ("postgresql+psycopg2",)  # needs #5824 for asyncpg
 
     @classmethod
     def define_tables(cls, metadata):
@@ -4498,44 +4499,49 @@ class JoinTest(QueryTest, AssertsCompiledSQL):
                 warnings += (join_aliased_dep,)
             # load a user who has an order that contains item id 3 and address
             # id 1 (order 3, owned by jack)
-            with testing.expect_deprecated_20(*warnings):
-                result = (
-                    fixture_session()
-                    .query(User)
-                    .join("orders", "items", aliased=aliased_)
-                    .filter_by(id=3)
-                    .reset_joinpoint()
-                    .join("orders", "address", aliased=aliased_)
-                    .filter_by(id=1)
-                    .all()
-                )
-            assert [User(id=7, name="jack")] == result
 
-            with testing.expect_deprecated_20(*warnings):
-                result = (
-                    fixture_session()
-                    .query(User)
-                    .join("orders", "items", aliased=aliased_, isouter=True)
-                    .filter_by(id=3)
-                    .reset_joinpoint()
-                    .join("orders", "address", aliased=aliased_, isouter=True)
-                    .filter_by(id=1)
-                    .all()
-                )
-            assert [User(id=7, name="jack")] == result
-
-            with testing.expect_deprecated_20(*warnings):
-                result = (
-                    fixture_session()
-                    .query(User)
-                    .outerjoin("orders", "items", aliased=aliased_)
-                    .filter_by(id=3)
-                    .reset_joinpoint()
-                    .outerjoin("orders", "address", aliased=aliased_)
-                    .filter_by(id=1)
-                    .all()
-                )
-            assert [User(id=7, name="jack")] == result
+            with fixture_session() as sess:
+                with testing.expect_deprecated_20(*warnings):
+                    result = (
+                        sess.query(User)
+                        .join("orders", "items", aliased=aliased_)
+                        .filter_by(id=3)
+                        .reset_joinpoint()
+                        .join("orders", "address", aliased=aliased_)
+                        .filter_by(id=1)
+                        .all()
+                    )
+                assert [User(id=7, name="jack")] == result
+
+            with fixture_session() as sess:
+                with testing.expect_deprecated_20(*warnings):
+                    result = (
+                        sess.query(User)
+                        .join(
+                            "orders", "items", aliased=aliased_, isouter=True
+                        )
+                        .filter_by(id=3)
+                        .reset_joinpoint()
+                        .join(
+                            "orders", "address", aliased=aliased_, isouter=True
+                        )
+                        .filter_by(id=1)
+                        .all()
+                    )
+                assert [User(id=7, name="jack")] == result
+
+            with fixture_session() as sess:
+                with testing.expect_deprecated_20(*warnings):
+                    result = (
+                        sess.query(User)
+                        .outerjoin("orders", "items", aliased=aliased_)
+                        .filter_by(id=3)
+                        .reset_joinpoint()
+                        .outerjoin("orders", "address", aliased=aliased_)
+                        .filter_by(id=1)
+                        .all()
+                    )
+                assert [User(id=7, name="jack")] == result
 
 
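The recurring change in these hunks is `sess = fixture_session()` becoming `with fixture_session() as sess:`. A minimal sketch (stub class, not SQLAlchemy's `Session`) of why the context-manager form matters for connection cleanup: `close()` runs even when an assertion fails mid-block.

```python
class StubSession:
    """Stand-in for a Session that releases resources on close()."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # unconditional cleanup, exception or not
        self.close()

    def close(self):
        self.closed = True

sess = StubSession()
try:
    with sess:
        raise AssertionError("simulated test failure")
except AssertionError:
    pass

print(sess.closed)  # True
```

With the old pattern, a failing `eq_()` left the session holding its connection, forcing the defensive pool-disposal logic this commit removes.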
 class AliasFromCorrectLeftTest(
index 4498fc1ff92f6ef3754fe705666527b865279df6..7eedb37c92879c5224c5f4e73ec719c257fbed96 100644 (file)
@@ -559,15 +559,15 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
                 5,
             ),
         ]:
-            sess = fixture_session()
+            with fixture_session() as sess:
 
-            def go():
-                eq_(
-                    sess.query(User).options(*opt).order_by(User.id).all(),
-                    self.static.user_item_keyword_result,
-                )
+                def go():
+                    eq_(
+                        sess.query(User).options(*opt).order_by(User.id).all(),
+                        self.static.user_item_keyword_result,
+                    )
 
-            self.assert_sql_count(testing.db, go, count)
+                self.assert_sql_count(testing.db, go, count)
 
     def test_disable_dynamic(self):
         """test no joined option on a dynamic."""
index 1c918a88cd09aa9b75b2ce8ad7a020742510d7a8..e85c23d6f6d831962ac02e04724aee668c37fb99 100644 (file)
@@ -45,13 +45,14 @@ from test.orm import _fixtures
 
 
 class _RemoveListeners(object):
-    def teardown(self):
+    @testing.fixture(autouse=True)
+    def _remove_listeners(self):
+        yield
         events.MapperEvents._clear()
         events.InstanceEvents._clear()
         events.SessionEvents._clear()
         events.InstrumentationEvents._clear()
         events.QueryEvents._clear()
-        super(_RemoveListeners, self).teardown()
 
 
 class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
@@ -1174,7 +1175,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
             argnames="target, event_name, fn",
         )(fn)
 
-    def teardown(self):
+    def teardown_test(self):
         A = self.classes.A
         A._sa_class_manager.dispatch._clear()
 
index cc959646644504bc5c06c23d6b2a707c7de09930..f622bff0250ba7791794309f851051ded71731f0 100644 (file)
@@ -2395,16 +2395,19 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         ]
 
         adalias = addresses.alias()
-        q = (
-            fixture_session()
-            .query(User)
-            .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
-            .outerjoin(adalias, "addresses")
-            .group_by(users)
-            .order_by(users.c.id)
-        )
 
-        assert q.all() == expected
+        with fixture_session() as sess:
+            q = (
+                sess.query(User)
+                .add_columns(
+                    func.count(adalias.c.id), ("Name:" + users.c.name)
+                )
+                .outerjoin(adalias, "addresses")
+                .group_by(users)
+                .order_by(users.c.id)
+            )
+
+            eq_(q.all(), expected)
 
         # test with a straight statement
         s = (
@@ -2417,52 +2420,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
             .group_by(*[c for c in users.c])
             .order_by(users.c.id)
         )
-        q = fixture_session().query(User)
-        result = (
-            q.add_columns(s.selected_columns.count, s.selected_columns.concat)
-            .from_statement(s)
-            .all()
-        )
-        assert result == expected
-
-        sess.expunge_all()
 
-        # test with select_entity_from()
-        q = (
-            fixture_session()
-            .query(User)
-            .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
-            .select_entity_from(users.outerjoin(addresses))
-            .group_by(users)
-            .order_by(users.c.id)
-        )
+        with fixture_session() as sess:
+            q = sess.query(User)
+            result = (
+                q.add_columns(
+                    s.selected_columns.count, s.selected_columns.concat
+                )
+                .from_statement(s)
+                .all()
+            )
+            eq_(result, expected)
 
-        assert q.all() == expected
-        sess.expunge_all()
+        with fixture_session() as sess:
+            # test with select_entity_from()
+            q = (
+                sess
+                .query(User)

+                .add_columns(
+                    func.count(addresses.c.id), ("Name:" + users.c.name)
+                )
+                .select_entity_from(users.outerjoin(addresses))
+                .group_by(users)
+                .order_by(users.c.id)
+            )
 
-        q = (
-            fixture_session()
-            .query(User)
-            .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name))
-            .outerjoin("addresses")
-            .group_by(users)
-            .order_by(users.c.id)
-        )
+            eq_(q.all(), expected)
 
-        assert q.all() == expected
-        sess.expunge_all()
+        with fixture_session() as sess:
+            q = (
+                sess.query(User)
+                .add_columns(
+                    func.count(addresses.c.id), ("Name:" + users.c.name)
+                )
+                .outerjoin("addresses")
+                .group_by(users)
+                .order_by(users.c.id)
+            )
+            eq_(q.all(), expected)
 
-        q = (
-            fixture_session()
-            .query(User)
-            .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name))
-            .outerjoin(adalias, "addresses")
-            .group_by(users)
-            .order_by(users.c.id)
-        )
+        with fixture_session() as sess:
+            q = (
+                sess.query(User)
+                .add_columns(
+                    func.count(adalias.c.id), ("Name:" + users.c.name)
+                )
+                .outerjoin(adalias, "addresses")
+                .group_by(users)
+                .order_by(users.c.id)
+            )
 
-        assert q.all() == expected
-        sess.expunge_all()
+            eq_(q.all(), expected)
 
     def test_expression_selectable_matches_mzero(self):
         User, Address = self.classes.User, self.classes.Address
index 3061de309b2bef4d1580a6c2bf73e43b9e1c530f..43cf81e6d4159c75feb1ca166236994706402995 100644 (file)
@@ -717,24 +717,24 @@ class LazyTest(_fixtures.FixtureTest):
                 ),
             )
 
-            sess = fixture_session()
+            with fixture_session() as sess:
 
-            # load address
-            a1 = (
-                sess.query(Address)
-                .filter_by(email_address="ed@wood.com")
-                .one()
-            )
+                # load address
+                a1 = (
+                    sess.query(Address)
+                    .filter_by(email_address="ed@wood.com")
+                    .one()
+                )
 
-            # load user that is attached to the address
-            u1 = sess.query(User).get(8)
+                # load user that is attached to the address
+                u1 = sess.query(User).get(8)
 
-            def go():
-                # lazy load of a1.user should get it from the session
-                assert a1.user is u1
+                def go():
+                    # lazy load of a1.user should get it from the session
+                    assert a1.user is u1
 
-            self.assert_sql_count(testing.db, go, 0)
-            sa.orm.clear_mappers()
+                self.assert_sql_count(testing.db, go, 0)
+                sa.orm.clear_mappers()
 
     def test_uses_get_compatible_types(self):
         """test the use_get optimization with compatible
@@ -789,24 +789,23 @@ class LazyTest(_fixtures.FixtureTest):
                 properties=dict(user=relationship(mapper(User, users))),
             )
 
-            sess = fixture_session()
-
-            # load address
-            a1 = (
-                sess.query(Address)
-                .filter_by(email_address="ed@wood.com")
-                .one()
-            )
+            with fixture_session() as sess:
+                # load address
+                a1 = (
+                    sess.query(Address)
+                    .filter_by(email_address="ed@wood.com")
+                    .one()
+                )
 
-            # load user that is attached to the address
-            u1 = sess.query(User).get(8)
+                # load user that is attached to the address
+                u1 = sess.query(User).get(8)
 
-            def go():
-                # lazy load of a1.user should get it from the session
-                assert a1.user is u1
+                def go():
+                    # lazy load of a1.user should get it from the session
+                    assert a1.user is u1
 
-            self.assert_sql_count(testing.db, go, 0)
-            sa.orm.clear_mappers()
+                self.assert_sql_count(testing.db, go, 0)
+                sa.orm.clear_mappers()
 
     def test_many_to_one(self):
         users, Address, addresses, User = (
index 0e8ac97e3e061245c44697c4438c05ca800d7703..42b5b3e459087def1ffd15a73f1af56b4776442b 100644 (file)
@@ -9,14 +9,12 @@ from sqlalchemy.orm import Session
 from sqlalchemy.orm.attributes import instance_state
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 
 
-engine = testing.db
-
-
 class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
-    def setUp(self):
+    def setup_test(self):
         global Parent, Child, Base
         Base = declarative_base()
 
@@ -36,27 +34,27 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
             )
             parent_id = Column(Integer, ForeignKey("parent.id"))
 
-        Base.metadata.create_all(engine)
+        Base.metadata.create_all(testing.db)
 
-    def tearDown(self):
-        Base.metadata.drop_all(engine)
+    def teardown_test(self):
+        Base.metadata.drop_all(testing.db)
 
     def test_annoying_autoflush_one(self):
-        sess = Session(engine)
+        sess = fixture_session()
 
         p1 = Parent()
         sess.add(p1)
         p1.children = []
 
     def test_annoying_autoflush_two(self):
-        sess = Session(engine)
+        sess = fixture_session()
 
         p1 = Parent()
         sess.add(p1)
         assert p1.children == []
 
     def test_dont_load_if_no_keys(self):
-        sess = Session(engine)
+        sess = fixture_session()
 
         p1 = Parent()
         sess.add(p1)
@@ -68,7 +66,9 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase):
 
 
 class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
-    def setUp(self):
+    __leave_connections_for_teardown__ = True
+
+    def setup_test(self):
         global Parent, Child, Base
         Base = declarative_base()
 
@@ -91,10 +91,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
 
             parent = relationship(Parent, backref=backref("children"))
 
-        Base.metadata.create_all(engine)
+        Base.metadata.create_all(testing.db)
 
         global sess, p1, p2, c1, c2
-        sess = Session(bind=engine)
+        sess = Session(bind=testing.db)
 
         p1 = Parent()
         p2 = Parent()
@@ -105,9 +105,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase):
 
         sess.commit()
 
-    def tearDown(self):
+    def teardown_test(self):
         sess.rollback()
-        Base.metadata.drop_all(engine)
+        Base.metadata.drop_all(testing.db)
 
     def test_load_on_pending_allows_backref_event(self):
         Child.parent.property.load_on_pending = True
index 013eb21e1169437dda3f4e61f429139938705d3e..d182fd2c1748c69003d670b411a670f4082129d2 100644 (file)
@@ -2560,7 +2560,7 @@ class MagicNamesTest(fixtures.MappedTest):
 
 
 class DocumentTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
 
         self.mapper = registry().map_imperatively
 
@@ -2624,14 +2624,14 @@ class DocumentTest(fixtures.TestBase):
 
 
 class ORMLoggingTest(_fixtures.FixtureTest):
-    def setup(self):
+    def setup_test(self):
         self.buf = logging.handlers.BufferingHandler(100)
         for log in [logging.getLogger("sqlalchemy.orm")]:
             log.addHandler(self.buf)
 
         self.mapper = registry().map_imperatively
 
-    def teardown(self):
+    def teardown_test(self):
         for log in [logging.getLogger("sqlalchemy.orm")]:
             log.removeHandler(self.buf)
 
index b22b318e9b7a4d1fd55fa2998fab78ca734bf31f..6f47c1238afebeb04fe632c6cdf7df071cc375ab 100644 (file)
@@ -1523,9 +1523,7 @@ class PickleTest(PathTest, QueryTest):
 
 class LocalOptsTest(PathTest, QueryTest):
     @classmethod
-    def setup_class(cls):
-        super(LocalOptsTest, cls).setup_class()
-
+    def setup_test_class(cls):
         @strategy_options.loader_option()
         def some_col_opt_only(loadopt, key, opts):
             return loadopt.set_column_strategy(
index fd8e849fb57d95f3b9f4924d472e5047f1fa0b54..7546ba1626e838196c026b624b1aec6340ae7c60 100644 (file)
@@ -6000,12 +6000,19 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
             [User.orders_syn, Order.items_syn],
             [User.orders_syn_2, Order.items_syn],
         ):
-            q = fixture_session().query(User)
-            for path in j:
-                q = q.join(path)
-            q = q.filter_by(id=3)
-            result = q.all()
-            assert [User(id=7, name="jack"), User(id=9, name="fred")] == result
+            with fixture_session() as sess:
+                q = sess.query(User)
+                for path in j:
+                    q = q.join(path)
+                q = q.filter_by(id=3)
+                result = q.all()
+                eq_(
+                    result,
+                    [
+                        User(id=7, name="jack"),
+                        User(id=9, name="fred"),
+                    ],
+                )
 
     def test_with_parent(self):
         Order, User = self.classes.Order, self.classes.User
@@ -6018,17 +6025,17 @@ class SynonymTest(QueryTest, AssertsCompiledSQL):
             ("name_syn", "orders_syn"),
             ("name_syn", "orders_syn_2"),
         ):
-            sess = fixture_session()
-            q = sess.query(User)
+            with fixture_session() as sess:
+                q = sess.query(User)
 
-            u1 = q.filter_by(**{nameprop: "jack"}).one()
+                u1 = q.filter_by(**{nameprop: "jack"}).one()
 
-            o = sess.query(Order).with_parent(u1, property=orderprop).all()
-            assert [
-                Order(description="order 1"),
-                Order(description="order 3"),
-                Order(description="order 5"),
-            ] == o
+                o = sess.query(Order).with_parent(u1, property=orderprop).all()
+                assert [
+                    Order(description="order 1"),
+                    Order(description="order 3"),
+                    Order(description="order 5"),
+                ] == o
 
     def test_froms_aliased_col(self):
         Address, User = self.classes.Address, self.classes.User
index 12c084b2d23696db91ec86a54928da3414e20903..ef1bf2e603d9fbbf68104332df60369d2928f310 100644 (file)
@@ -26,7 +26,7 @@ from sqlalchemy.testing import mock
 
 class _JoinFixtures(object):
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         m = MetaData()
         cls.left = Table(
             "lft",
index 5979f08ae61c2a2d5e4bb0502a3289bb56bf0530..8d73cd40e1e6a8d01e164323a4785916f781020d 100644 (file)
@@ -637,7 +637,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase):
 
     """
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     def _fixture_one(
@@ -2474,7 +2474,7 @@ class JoinConditionErrorTest(fixtures.TestBase):
 
         assert_raises(sa.exc.ArgumentError, configure_mappers)
 
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
 
@@ -4354,7 +4354,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest):
 
 
 class SecondaryArgTest(fixtures.TestBase):
-    def teardown(self):
+    def teardown_test(self):
         clear_mappers()
 
     @testing.combinations((True,), (False,))
index 5535fe5d68e854e262e809f00cba0fba895c0454..4895c7d3a0780b4d64f5b26768e92e4348792966 100644 (file)
@@ -713,35 +713,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     def _do_query_tests(self, opts, count):
         Order, User = self.classes.Order, self.classes.User
 
-        sess = fixture_session()
+        with fixture_session() as sess:
 
-        def go():
-            eq_(
-                sess.query(User).options(*opts).order_by(User.id).all(),
-                self.static.user_item_keyword_result,
-            )
+            def go():
+                eq_(
+                    sess.query(User).options(*opts).order_by(User.id).all(),
+                    self.static.user_item_keyword_result,
+                )
 
-        self.assert_sql_count(testing.db, go, count)
+            self.assert_sql_count(testing.db, go, count)
 
-        eq_(
-            sess.query(User)
-            .options(*opts)
-            .filter(User.name == "fred")
-            .order_by(User.id)
-            .all(),
-            self.static.user_item_keyword_result[2:3],
-        )
+            eq_(
+                sess.query(User)
+                .options(*opts)
+                .filter(User.name == "fred")
+                .order_by(User.id)
+                .all(),
+                self.static.user_item_keyword_result[2:3],
+            )
 
-        sess = fixture_session()
-        eq_(
-            sess.query(User)
-            .options(*opts)
-            .join(User.orders)
-            .filter(Order.id == 3)
-            .order_by(User.id)
-            .all(),
-            self.static.user_item_keyword_result[0:1],
-        )
+        with fixture_session() as sess:
+            eq_(
+                sess.query(User)
+                .options(*opts)
+                .join(User.orders)
+                .filter(Order.id == 3)
+                .order_by(User.id)
+                .all(),
+                self.static.user_item_keyword_result[0:1],
+            )
 
     def test_cyclical(self):
         """A circular eager relationship breaks the cycle with a lazy loader"""
index 20c4752b82dff08478744c35b9ea78ddaeb25a89..3d4566af3eeecb8f2f1bbd0c771e0365af242fa1 100644 (file)
@@ -1273,7 +1273,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
 
     run_inserts = None
 
-    def setup(self):
+    def setup_test(self):
         mapper(self.classes.User, self.tables.users)
 
     def _assert_modified(self, u1):
@@ -1288,11 +1288,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
     def _assert_no_cycle(self, u1):
         assert sa.orm.attributes.instance_state(u1)._strong_obj is None
 
-    def _persistent_fixture(self):
+    def _persistent_fixture(self, gc_collect=False):
         User = self.classes.User
         u1 = User()
         u1.name = "ed"
-        sess = fixture_session()
+        if gc_collect:
+            sess = Session(testing.db)
+        else:
+            sess = fixture_session()
         sess.add(u1)
         sess.flush()
         return sess, u1
@@ -1389,14 +1392,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest):
 
     @testing.requires.predictable_gc
     def test_move_gc_session_persistent_dirty(self):
-        sess, u1 = self._persistent_fixture()
+        sess, u1 = self._persistent_fixture(gc_collect=True)
         u1.name = "edchanged"
         self._assert_cycle(u1)
         self._assert_modified(u1)
         del sess
         gc_collect()
         self._assert_cycle(u1)
-        s2 = fixture_session()
+        s2 = Session(testing.db)
         s2.add(u1)
         self._assert_cycle(u1)
         self._assert_modified(u1)
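The switch to `Session(testing.db)` under `gc_collect=True` follows from the commit's note that `fixture_session()` now strongly references the sessions it creates: a registered session can never be garbage-collected mid-test, so gc-sensitive tests must use a plain `Session`. A toy model of that constraint (stand-in classes, assumed registry behavior):

```python
import gc
import weakref

class Session:  # stand-in, not sqlalchemy.orm.Session
    pass

_registry = []  # plays the role of fixture_session()'s strong references

def fixture_session():
    s = Session()
    _registry.append(s)  # strong reference held until explicit teardown
    return s

s1 = fixture_session()
ref1 = weakref.ref(s1)
del s1
gc.collect()
assert ref1() is not None  # registry keeps it alive

s2 = Session()             # plain Session: no registry entry
ref2 = weakref.ref(s2)
del s2
gc.collect()
assert ref2() is None      # collectable once the test drops its reference
```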
@@ -1565,7 +1568,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
 
         mapper(User, users)
 
-        sess = fixture_session()
+        sess = Session(testing.db)
 
         u1 = User(name="u1")
         sess.add(u1)
@@ -1573,7 +1576,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest):
 
         # can't add u1 to Session,
         # already belongs to u2
-        s2 = fixture_session()
+        s2 = Session(testing.db)
         assert_raises_message(
             sa.exc.InvalidRequestError,
             r".*is already attached to session",
@@ -1725,11 +1728,10 @@ class DisposedStates(fixtures.MappedTest):
 
         mapper(T, cls.tables.t1)
 
-    def teardown(self):
+    def teardown_test(self):
         from sqlalchemy.orm.session import _sessions
 
         _sessions.clear()
-        super(DisposedStates, self).teardown()
 
     def _set_imap_in_disposal(self, sess, *objs):
         """remove selected objects from the given session, as though
index fe20442a30f812ffa5e5130fb9dc96d2db184d2d..150cee2225392411ecf423a4ef85fde352184709 100644 (file)
@@ -734,35 +734,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     def _do_query_tests(self, opts, count):
         Order, User = self.classes.Order, self.classes.User
 
-        sess = fixture_session()
+        with fixture_session() as sess:
 
-        def go():
-            eq_(
-                sess.query(User).options(*opts).order_by(User.id).all(),
-                self.static.user_item_keyword_result,
-            )
+            def go():
+                eq_(
+                    sess.query(User).options(*opts).order_by(User.id).all(),
+                    self.static.user_item_keyword_result,
+                )
 
-        self.assert_sql_count(testing.db, go, count)
+            self.assert_sql_count(testing.db, go, count)
 
-        eq_(
-            sess.query(User)
-            .options(*opts)
-            .filter(User.name == "fred")
-            .order_by(User.id)
-            .all(),
-            self.static.user_item_keyword_result[2:3],
-        )
+            eq_(
+                sess.query(User)
+                .options(*opts)
+                .filter(User.name == "fred")
+                .order_by(User.id)
+                .all(),
+                self.static.user_item_keyword_result[2:3],
+            )
 
-        sess = fixture_session()
-        eq_(
-            sess.query(User)
-            .options(*opts)
-            .join(User.orders)
-            .filter(Order.id == 3)
-            .order_by(User.id)
-            .all(),
-            self.static.user_item_keyword_result[0:1],
-        )
+        with fixture_session() as sess:
+            eq_(
+                sess.query(User)
+                .options(*opts)
+                .join(User.orders)
+                .filter(Order.id == 3)
+                .order_by(User.id)
+                .all(),
+                self.static.user_item_keyword_result[0:1],
+            )
 
     def test_cyclical(self):
         """A circular eager relationship breaks the cycle with a lazy loader"""
index 550cf6535b04396e0b336c486e7553936dc8dca9..7f77b01c781e2c6a9f1dd28dbaf27e822ab95d33 100644 (file)
@@ -66,17 +66,18 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
-        conn = testing.db.connect()
-        trans = conn.begin()
-        sess = Session(bind=conn, autocommit=False, autoflush=True)
-        sess.begin(subtransactions=True)
-        u = User(name="ed")
-        sess.add(u)
-        sess.flush()
-        sess.commit()  # commit does nothing
-        trans.rollback()  # rolls back
-        assert len(sess.query(User).all()) == 0
-        sess.close()
+
+        with testing.db.connect() as conn:
+            trans = conn.begin()
+            sess = Session(bind=conn, autocommit=False, autoflush=True)
+            sess.begin(subtransactions=True)
+            u = User(name="ed")
+            sess.add(u)
+            sess.flush()
+            sess.commit()  # commit does nothing
+            trans.rollback()  # rolls back
+            assert len(sess.query(User).all()) == 0
+            sess.close()
 
     @engines.close_open_connections
     def test_subtransaction_on_external_no_begin(self):
@@ -260,34 +261,33 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
         users = self.tables.users
 
         engine = Engine._future_facade(testing.db)
-        session = Session(engine, autocommit=False)
-
-        session.begin()
-        session.connection().execute(users.insert().values(name="user1"))
-        session.begin(subtransactions=True)
-        session.begin_nested()
-        session.connection().execute(users.insert().values(name="user2"))
-        assert (
-            session.connection()
-            .exec_driver_sql("select count(1) from users")
-            .scalar()
-            == 2
-        )
-        session.rollback()
-        assert (
-            session.connection()
-            .exec_driver_sql("select count(1) from users")
-            .scalar()
-            == 1
-        )
-        session.connection().execute(users.insert().values(name="user3"))
-        session.commit()
-        assert (
-            session.connection()
-            .exec_driver_sql("select count(1) from users")
-            .scalar()
-            == 2
-        )
+        with Session(engine, autocommit=False) as session:
+            session.begin()
+            session.connection().execute(users.insert().values(name="user1"))
+            session.begin(subtransactions=True)
+            session.begin_nested()
+            session.connection().execute(users.insert().values(name="user2"))
+            assert (
+                session.connection()
+                .exec_driver_sql("select count(1) from users")
+                .scalar()
+                == 2
+            )
+            session.rollback()
+            assert (
+                session.connection()
+                .exec_driver_sql("select count(1) from users")
+                .scalar()
+                == 1
+            )
+            session.connection().execute(users.insert().values(name="user3"))
+            session.commit()
+            assert (
+                session.connection()
+                .exec_driver_sql("select count(1) from users")
+                .scalar()
+                == 2
+            )
 
     @testing.requires.savepoints
     def test_dirty_state_transferred_deep_nesting(self):
@@ -295,27 +295,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
 
         mapper(User, users)
 
-        s = Session(testing.db)
-        u1 = User(name="u1")
-        s.add(u1)
-        s.commit()
-
-        nt1 = s.begin_nested()
-        nt2 = s.begin_nested()
-        u1.name = "u2"
-        assert attributes.instance_state(u1) not in nt2._dirty
-        assert attributes.instance_state(u1) not in nt1._dirty
-        s.flush()
-        assert attributes.instance_state(u1) in nt2._dirty
-        assert attributes.instance_state(u1) not in nt1._dirty
+        with fixture_session() as s:
+            u1 = User(name="u1")
+            s.add(u1)
+            s.commit()
+
+            nt1 = s.begin_nested()
+            nt2 = s.begin_nested()
+            u1.name = "u2"
+            assert attributes.instance_state(u1) not in nt2._dirty
+            assert attributes.instance_state(u1) not in nt1._dirty
+            s.flush()
+            assert attributes.instance_state(u1) in nt2._dirty
+            assert attributes.instance_state(u1) not in nt1._dirty
 
-        s.commit()
-        assert attributes.instance_state(u1) in nt2._dirty
-        assert attributes.instance_state(u1) in nt1._dirty
+            s.commit()
+            assert attributes.instance_state(u1) in nt2._dirty
+            assert attributes.instance_state(u1) in nt1._dirty
 
-        s.rollback()
-        assert attributes.instance_state(u1).expired
-        eq_(u1.name, "u1")
+            s.rollback()
+            assert attributes.instance_state(u1).expired
+            eq_(u1.name, "u1")
 
     @testing.requires.savepoints
     def test_dirty_state_transferred_deep_nesting_future(self):
@@ -323,27 +323,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
 
         mapper(User, users)
 
-        s = Session(testing.db, future=True)
-        u1 = User(name="u1")
-        s.add(u1)
-        s.commit()
-
-        nt1 = s.begin_nested()
-        nt2 = s.begin_nested()
-        u1.name = "u2"
-        assert attributes.instance_state(u1) not in nt2._dirty
-        assert attributes.instance_state(u1) not in nt1._dirty
-        s.flush()
-        assert attributes.instance_state(u1) in nt2._dirty
-        assert attributes.instance_state(u1) not in nt1._dirty
+        with fixture_session(future=True) as s:
+            u1 = User(name="u1")
+            s.add(u1)
+            s.commit()
+
+            nt1 = s.begin_nested()
+            nt2 = s.begin_nested()
+            u1.name = "u2"
+            assert attributes.instance_state(u1) not in nt2._dirty
+            assert attributes.instance_state(u1) not in nt1._dirty
+            s.flush()
+            assert attributes.instance_state(u1) in nt2._dirty
+            assert attributes.instance_state(u1) not in nt1._dirty
 
-        nt2.commit()
-        assert attributes.instance_state(u1) in nt2._dirty
-        assert attributes.instance_state(u1) in nt1._dirty
+            nt2.commit()
+            assert attributes.instance_state(u1) in nt2._dirty
+            assert attributes.instance_state(u1) in nt1._dirty
 
-        nt1.rollback()
-        assert attributes.instance_state(u1).expired
-        eq_(u1.name, "u1")
+            nt1.rollback()
+            assert attributes.instance_state(u1).expired
+            eq_(u1.name, "u1")
 
     @testing.requires.independent_connections
     def test_transactions_isolated(self):
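The savepoint tests above exercise the nesting semantics that `begin_nested()` maps onto SQL SAVEPOINTs: rolling back to an inner savepoint undoes only the work done after it. The underlying SQL behavior can be shown with the stdlib `sqlite3` module standing in for the test database:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.isolation_level = None  # autocommit mode; we control transactions manually
conn.execute("CREATE TABLE users (name TEXT)")

conn.execute("BEGIN")
conn.execute("INSERT INTO users VALUES ('u1')")
conn.execute("SAVEPOINT nt1")
conn.execute("SAVEPOINT nt2")
conn.execute("INSERT INTO users VALUES ('u2')")

# rolling back to the inner savepoint undoes only 'u2', not 'u1'
conn.execute("ROLLBACK TO SAVEPOINT nt2")
count = conn.execute("SELECT count(*) FROM users").fetchone()[0]
conn.execute("COMMIT")
# count == 1: only the row inserted before the savepoints survives
```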
@@ -1049,23 +1049,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest):
 
         mapper(User, users)
 
-        session = Session(testing.db)
+        with fixture_session() as session:
 
-        with expect_warnings(".*during handling of a previous exception.*"):
-            session.begin_nested()
-            savepoint = session.connection()._nested_transaction._savepoint
+            with expect_warnings(
+                ".*during handling of a previous exception.*"
+            ):
+                session.begin_nested()
+                savepoint = session.connection()._nested_transaction._savepoint
 
-            # force the savepoint to disappear
-            session.connection().dialect.do_release_savepoint(
-                session.connection(), savepoint
-            )
+                # force the savepoint to disappear
+                session.connection().dialect.do_release_savepoint(
+                    session.connection(), savepoint
+                )
 
-            # now do a broken flush
-            session.add_all([User(id=1), User(id=1)])
+                # now do a broken flush
+                session.add_all([User(id=1), User(id=1)])
 
-            assert_raises_message(
-                sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
-            )
+                assert_raises_message(
+                    sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush
+                )
 
 
 class _LocalFixture(FixtureTest):
@@ -1170,39 +1172,40 @@ class SubtransactionRecipeTest(FixtureTest):
     def test_recipe_heavy_nesting(self, subtransaction_recipe):
         users = self.tables.users
 
-        session = Session(testing.db, future=self.future)
-
-        with subtransaction_recipe(session):
-            session.connection().execute(users.insert().values(name="user1"))
+        with fixture_session(future=self.future) as session:
             with subtransaction_recipe(session):
-                savepoint = session.begin_nested()
                 session.connection().execute(
-                    users.insert().values(name="user2")
+                    users.insert().values(name="user1")
                 )
+                with subtransaction_recipe(session):
+                    savepoint = session.begin_nested()
+                    session.connection().execute(
+                        users.insert().values(name="user2")
+                    )
+                    assert (
+                        session.connection()
+                        .exec_driver_sql("select count(1) from users")
+                        .scalar()
+                        == 2
+                    )
+                    savepoint.rollback()
+
+                with subtransaction_recipe(session):
+                    assert (
+                        session.connection()
+                        .exec_driver_sql("select count(1) from users")
+                        .scalar()
+                        == 1
+                    )
+                    session.connection().execute(
+                        users.insert().values(name="user3")
+                    )
                 assert (
                     session.connection()
                     .exec_driver_sql("select count(1) from users")
                     .scalar()
                     == 2
                 )
-                savepoint.rollback()
-
-            with subtransaction_recipe(session):
-                assert (
-                    session.connection()
-                    .exec_driver_sql("select count(1) from users")
-                    .scalar()
-                    == 1
-                )
-                session.connection().execute(
-                    users.insert().values(name="user3")
-                )
-            assert (
-                session.connection()
-                .exec_driver_sql("select count(1) from users")
-                .scalar()
-                == 2
-            )
 
     @engines.close_open_connections
     def test_recipe_subtransaction_on_external_subtrans(
@@ -1228,13 +1231,12 @@ class SubtransactionRecipeTest(FixtureTest):
         User, users = self.classes.User, self.tables.users
 
         mapper(User, users)
-        sess = Session(testing.db, future=self.future)
-
-        with subtransaction_recipe(sess):
-            u = User(name="u1")
-            sess.add(u)
-        sess.close()
-        assert len(sess.query(User).all()) == 1
+        with fixture_session(future=self.future) as sess:
+            with subtransaction_recipe(sess):
+                u = User(name="u1")
+                sess.add(u)
+            sess.close()
+            assert len(sess.query(User).all()) == 1
 
     def test_recipe_subtransaction_on_noautocommit(
         self, subtransaction_recipe
@@ -1242,16 +1244,15 @@ class SubtransactionRecipeTest(FixtureTest):
         User, users = self.classes.User, self.tables.users
 
         mapper(User, users)
-        sess = Session(testing.db, future=self.future)
-
-        sess.begin()
-        with subtransaction_recipe(sess):
-            u = User(name="u1")
-            sess.add(u)
-            sess.flush()
-        sess.rollback()  # rolls back
-        assert len(sess.query(User).all()) == 0
-        sess.close()
+        with fixture_session(future=self.future) as sess:
+            sess.begin()
+            with subtransaction_recipe(sess):
+                u = User(name="u1")
+                sess.add(u)
+                sess.flush()
+            sess.rollback()  # rolls back
+            assert len(sess.query(User).all()) == 0
+            sess.close()
 
     @testing.requires.savepoints
     def test_recipe_mixed_transaction_control(self, subtransaction_recipe):
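The `subtransaction_recipe` fixture used in these tests follows the documented "subtransaction" pattern: nested begin/commit pairs are counted, only the outermost commit takes effect, and any rollback ends the whole transaction. A toy sketch of that counting behavior, with hypothetical class and fixture names:

```python
from contextlib import contextmanager

class RecipeSession:
    """Hypothetical session that counts begin/commit depth."""
    def __init__(self):
        self.depth = 0
        self.committed = False
        self.rolled_back = False

    def begin(self):
        self.depth += 1

    def commit(self):
        self.depth -= 1
        if self.depth == 0:
            self.committed = True  # only the outermost commit is real

    def rollback(self):
        self.depth = 0
        self.rolled_back = True

@contextmanager
def subtransaction_recipe(session):
    """Commit on clean exit of the outermost block; roll back on error."""
    session.begin()
    try:
        yield
    except Exception:
        session.rollback()
        raise
    else:
        session.commit()

sess = RecipeSession()
with subtransaction_recipe(sess):
    with subtransaction_recipe(sess):
        pass
    assert not sess.committed  # inner exit did not commit
assert sess.committed          # outer exit committed exactly once
```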
@@ -1259,30 +1260,28 @@ class SubtransactionRecipeTest(FixtureTest):
 
         mapper(User, users)
 
-        sess = Session(testing.db, future=self.future)
+        with fixture_session(future=self.future) as sess:
 
-        sess.begin()
-        sess.begin_nested()
+            sess.begin()
+            sess.begin_nested()
 
-        with subtransaction_recipe(sess):
+            with subtransaction_recipe(sess):
 
-            sess.add(User(name="u1"))
+                sess.add(User(name="u1"))
 
-        sess.commit()
-        sess.commit()
+            sess.commit()
+            sess.commit()
 
-        eq_(len(sess.query(User).all()), 1)
-        sess.close()
+            eq_(len(sess.query(User).all()), 1)
+            sess.close()
 
-        t1 = sess.begin()
-        t2 = sess.begin_nested()
-
-        sess.add(User(name="u2"))
+            t1 = sess.begin()
+            t2 = sess.begin_nested()
 
-        t2.commit()
-        assert sess._legacy_transaction() is t1
+            sess.add(User(name="u2"))
 
-        sess.close()
+            t2.commit()
+            assert sess._legacy_transaction() is t1
 
     def test_recipe_error_on_using_inactive_session_commands(
         self, subtransaction_recipe
@@ -1290,56 +1289,55 @@ class SubtransactionRecipeTest(FixtureTest):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
-        sess = Session(testing.db, future=self.future)
-        sess.begin()
-
-        try:
-            with subtransaction_recipe(sess):
-                sess.add(User(name="u1"))
-                sess.flush()
-                raise Exception("force rollback")
-        except:
-            pass
-
-        if self.recipe_rollsback_early:
-            # that was a real rollback, so no transaction
-            assert not sess.in_transaction()
-            is_(sess.get_transaction(), None)
-        else:
-            assert sess.in_transaction()
-
-        sess.close()
-        assert not sess.in_transaction()
-
-    def test_recipe_multi_nesting(self, subtransaction_recipe):
-        sess = Session(testing.db, future=self.future)
-
-        with subtransaction_recipe(sess):
-            assert sess.in_transaction()
+        with fixture_session(future=self.future) as sess:
+            sess.begin()
 
             try:
                 with subtransaction_recipe(sess):
-                    assert sess._legacy_transaction()
+                    sess.add(User(name="u1"))
+                    sess.flush()
                     raise Exception("force rollback")
             except:
                 pass
 
             if self.recipe_rollsback_early:
+                # that was a real rollback, so no transaction
                 assert not sess.in_transaction()
+                is_(sess.get_transaction(), None)
             else:
                 assert sess.in_transaction()
 
-        assert not sess.in_transaction()
+            sess.close()
+            assert not sess.in_transaction()
+
+    def test_recipe_multi_nesting(self, subtransaction_recipe):
+        with fixture_session(future=self.future) as sess:
+            with subtransaction_recipe(sess):
+                assert sess.in_transaction()
+
+                try:
+                    with subtransaction_recipe(sess):
+                        assert sess._legacy_transaction()
+                        raise Exception("force rollback")
+                except:
+                    pass
+
+                if self.recipe_rollsback_early:
+                    assert not sess.in_transaction()
+                else:
+                    assert sess.in_transaction()
+
+            assert not sess.in_transaction()
 
     def test_recipe_deactive_status_check(self, subtransaction_recipe):
-        sess = Session(testing.db, future=self.future)
-        sess.begin()
+        with fixture_session(future=self.future) as sess:
+            sess.begin()
 
-        with subtransaction_recipe(sess):
-            sess.rollback()
+            with subtransaction_recipe(sess):
+                sess.rollback()
 
-        assert not sess.in_transaction()
-        sess.commit()  # no error
+            assert not sess.in_transaction()
+            sess.commit()  # no error
 
 
 class FixtureDataTest(_LocalFixture):
@@ -1394,28 +1392,28 @@ class CleanSavepointTest(FixtureTest):
 
         mapper(User, users)
 
-        s = Session(bind=testing.db, future=future)
-        u1 = User(name="u1")
-        u2 = User(name="u2")
-        s.add_all([u1, u2])
-        s.commit()
-        u1.name
-        u2.name
-        trans = s._transaction
-        assert trans is not None
-        s.begin_nested()
-        update_fn(s, u2)
-        eq_(u2.name, "u2modified")
-        s.rollback()
+        with fixture_session(future=future) as s:
+            u1 = User(name="u1")
+            u2 = User(name="u2")
+            s.add_all([u1, u2])
+            s.commit()
+            u1.name
+            u2.name
+            trans = s._transaction
+            assert trans is not None
+            s.begin_nested()
+            update_fn(s, u2)
+            eq_(u2.name, "u2modified")
+            s.rollback()
 
-        if future:
-            assert s._transaction is None
-            assert "name" not in u1.__dict__
-        else:
-            assert s._transaction is trans
-            eq_(u1.__dict__["name"], "u1")
-        assert "name" not in u2.__dict__
-        eq_(u2.name, "u2")
+            if future:
+                assert s._transaction is None
+                assert "name" not in u1.__dict__
+            else:
+                assert s._transaction is trans
+                eq_(u1.__dict__["name"], "u1")
+            assert "name" not in u2.__dict__
+            eq_(u2.name, "u2")
 
     @testing.requires.savepoints
     def test_rollback_ignores_clean_on_savepoint(self):
@@ -2116,82 +2114,108 @@ class ContextManagerPlusFutureTest(FixtureTest):
         eq_(sess.query(User).count(), 1)
 
     def test_explicit_begin(self):
-        s1 = Session(testing.db)
-        with s1.begin() as trans:
-            is_(trans, s1._legacy_transaction())
-            s1.connection()
+        with fixture_session() as s1:
+            with s1.begin() as trans:
+                is_(trans, s1._legacy_transaction())
+                s1.connection()
 
-        is_(s1._transaction, None)
+            is_(s1._transaction, None)
 
     def test_no_double_begin_explicit(self):
-        s1 = Session(testing.db)
-        s1.begin()
-        assert_raises_message(
-            sa_exc.InvalidRequestError,
-            "A transaction is already begun on this Session.",
-            s1.begin,
-        )
+        with fixture_session() as s1:
+            s1.begin()
+            assert_raises_message(
+                sa_exc.InvalidRequestError,
+                "A transaction is already begun on this Session.",
+                s1.begin,
+            )
 
     @testing.requires.savepoints
     def test_future_rollback_is_global(self):
         users = self.tables.users
 
-        s1 = Session(testing.db, future=True)
+        with fixture_session(future=True) as s1:
+            s1.begin()
 
-        s1.begin()
+            s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
 
-        s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+            s1.begin_nested()
 
-        s1.begin_nested()
-
-        s1.connection().execute(
-            users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
-        )
+            s1.connection().execute(
+                users.insert(),
+                [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+            )
 
-        eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+            eq_(
+                s1.connection().scalar(
+                    select(func.count()).select_from(users)
+                ),
+                3,
+            )
 
-        # rolls back the whole transaction
-        s1.rollback()
-        is_(s1._legacy_transaction(), None)
+            # rolls back the whole transaction
+            s1.rollback()
+            is_(s1._legacy_transaction(), None)
 
-        eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0)
+            eq_(
+                s1.connection().scalar(
+                    select(func.count()).select_from(users)
+                ),
+                0,
+            )
 
-        s1.commit()
-        is_(s1._legacy_transaction(), None)
+            s1.commit()
+            is_(s1._legacy_transaction(), None)
 
     @testing.requires.savepoints
     def test_old_rollback_is_local(self):
         users = self.tables.users
 
-        s1 = Session(testing.db)
+        with fixture_session() as s1:
 
-        t1 = s1.begin()
+            t1 = s1.begin()
 
-        s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
+            s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}])
 
-        s1.begin_nested()
+            s1.begin_nested()
 
-        s1.connection().execute(
-            users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}]
-        )
+            s1.connection().execute(
+                users.insert(),
+                [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}],
+            )
 
-        eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3)
+            eq_(
+                s1.connection().scalar(
+                    select(func.count()).select_from(users)
+                ),
+                3,
+            )
 
-        # rolls back only the savepoint
-        s1.rollback()
+            # rolls back only the savepoint
+            s1.rollback()
 
-        is_(s1._legacy_transaction(), t1)
+            is_(s1._legacy_transaction(), t1)
 
-        eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
+            eq_(
+                s1.connection().scalar(
+                    select(func.count()).select_from(users)
+                ),
+                1,
+            )
 
-        s1.commit()
-        eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1)
-        is_not(s1._legacy_transaction(), None)
+            s1.commit()
+            eq_(
+                s1.connection().scalar(
+                    select(func.count()).select_from(users)
+                ),
+                1,
+            )
+            is_not(s1._legacy_transaction(), None)
 
     def test_session_as_ctx_manager_one(self):
         users = self.tables.users
 
-        with Session(testing.db) as sess:
+        with fixture_session() as sess:
             is_not(sess._legacy_transaction(), None)
 
             sess.connection().execute(
@@ -2212,7 +2236,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
     def test_session_as_ctx_manager_future_one(self):
         users = self.tables.users
 
-        with Session(testing.db, future=True) as sess:
+        with fixture_session(future=True) as sess:
             is_(sess._legacy_transaction(), None)
 
             sess.connection().execute(
@@ -2234,7 +2258,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
         users = self.tables.users
 
         try:
-            with Session(testing.db) as sess:
+            with fixture_session() as sess:
                 is_not(sess._legacy_transaction(), None)
 
                 sess.connection().execute(
@@ -2250,7 +2274,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
         users = self.tables.users
 
         try:
-            with Session(testing.db, future=True) as sess:
+            with fixture_session(future=True) as sess:
                 is_(sess._legacy_transaction(), None)
 
                 sess.connection().execute(
@@ -2265,7 +2289,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
     def test_begin_context_manager(self):
         users = self.tables.users
 
-        with Session(testing.db) as sess:
+        with fixture_session() as sess:
             with sess.begin():
                 sess.connection().execute(
                     users.insert().values(id=1, name="user1")
@@ -2296,12 +2320,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
 
         # committed
         eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+        sess.close()
 
     def test_begin_context_manager_rollback_trans(self):
         users = self.tables.users
 
         try:
-            with Session(testing.db) as sess:
+            with fixture_session() as sess:
                 with sess.begin():
                     sess.connection().execute(
                         users.insert().values(id=1, name="user1")
@@ -2318,12 +2343,13 @@ class ContextManagerPlusFutureTest(FixtureTest):
 
         # rolled back
         eq_(sess.connection().execute(users.select()).all(), [])
+        sess.close()
 
     def test_begin_context_manager_rollback_outer(self):
         users = self.tables.users
 
         try:
-            with Session(testing.db) as sess:
+            with fixture_session() as sess:
                 with sess.begin():
                     sess.connection().execute(
                         users.insert().values(id=1, name="user1")
@@ -2340,6 +2366,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
 
         # committed
         eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+        sess.close()
 
     def test_sessionmaker_begin_context_manager_rollback_trans(self):
         users = self.tables.users
@@ -2363,6 +2390,7 @@ class ContextManagerPlusFutureTest(FixtureTest):
 
         # rolled back
         eq_(sess.connection().execute(users.select()).all(), [])
+        sess.close()
 
     def test_sessionmaker_begin_context_manager_rollback_outer(self):
         users = self.tables.users
@@ -2386,36 +2414,37 @@ class ContextManagerPlusFutureTest(FixtureTest):
 
         # committed
         eq_(sess.connection().execute(users.select()).all(), [(1, "user1")])
+        sess.close()
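`with sess.begin():` in the hunks above commits on a clean exit and rolls back when the block raises; the `sess.close()` lines added after each test then release the connection for teardown. The stdlib `sqlite3` connection offers the same commit-or-rollback contract and can illustrate it:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")

# sqlite3 connections act as transaction context managers:
# commit on clean exit, rollback if the block raises.
with conn:
    conn.execute("INSERT INTO users VALUES (1, 'user1')")

try:
    with conn:
        conn.execute("INSERT INTO users VALUES (2, 'user2')")
        raise RuntimeError("force rollback")
except RuntimeError:
    pass

# only the committed row survives; the second insert was rolled back
rows = conn.execute("SELECT id, name FROM users").fetchall()
conn.close()
```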
 
 
 class TransactionFlagsTest(fixtures.TestBase):
     def test_in_transaction(self):
-        s1 = Session(testing.db)
+        with fixture_session() as s1:
 
-        eq_(s1.in_transaction(), False)
+            eq_(s1.in_transaction(), False)
 
-        trans = s1.begin()
+            trans = s1.begin()
 
-        eq_(s1.in_transaction(), True)
-        is_(s1.get_transaction(), trans)
+            eq_(s1.in_transaction(), True)
+            is_(s1.get_transaction(), trans)
 
-        n1 = s1.begin_nested()
+            n1 = s1.begin_nested()
 
-        eq_(s1.in_transaction(), True)
-        is_(s1.get_transaction(), trans)
-        is_(s1.get_nested_transaction(), n1)
+            eq_(s1.in_transaction(), True)
+            is_(s1.get_transaction(), trans)
+            is_(s1.get_nested_transaction(), n1)
 
-        n1.rollback()
+            n1.rollback()
 
-        is_(s1.get_nested_transaction(), None)
-        is_(s1.get_transaction(), trans)
+            is_(s1.get_nested_transaction(), None)
+            is_(s1.get_transaction(), trans)
 
-        eq_(s1.in_transaction(), True)
+            eq_(s1.in_transaction(), True)
 
-        s1.commit()
+            s1.commit()
 
-        eq_(s1.in_transaction(), False)
-        is_(s1.get_transaction(), None)
+            eq_(s1.in_transaction(), False)
+            is_(s1.get_transaction(), None)
 
     def test_in_transaction_subtransactions(self):
         """we'd like to do away with subtransactions for future sessions
@@ -2425,72 +2454,71 @@ class TransactionFlagsTest(fixtures.TestBase):
         the external API works.
 
         """
-        s1 = Session(testing.db)
-
-        eq_(s1.in_transaction(), False)
+        with fixture_session() as s1:
+            eq_(s1.in_transaction(), False)
 
-        trans = s1.begin()
+            trans = s1.begin()
 
-        eq_(s1.in_transaction(), True)
-        is_(s1.get_transaction(), trans)
+            eq_(s1.in_transaction(), True)
+            is_(s1.get_transaction(), trans)
 
-        subtrans = s1.begin(_subtrans=True)
-        is_(s1.get_transaction(), trans)
-        eq_(s1.in_transaction(), True)
+            subtrans = s1.begin(_subtrans=True)
+            is_(s1.get_transaction(), trans)
+            eq_(s1.in_transaction(), True)
 
-        is_(s1._transaction, subtrans)
+            is_(s1._transaction, subtrans)
 
-        s1.rollback()
+            s1.rollback()
 
-        eq_(s1.in_transaction(), True)
-        is_(s1._transaction, trans)
+            eq_(s1.in_transaction(), True)
+            is_(s1._transaction, trans)
 
-        s1.rollback()
+            s1.rollback()
 
-        eq_(s1.in_transaction(), False)
-        is_(s1._transaction, None)
+            eq_(s1.in_transaction(), False)
+            is_(s1._transaction, None)
 
     def test_in_transaction_nesting(self):
-        s1 = Session(testing.db)
+        with fixture_session() as s1:
 
-        eq_(s1.in_transaction(), False)
+            eq_(s1.in_transaction(), False)
 
-        trans = s1.begin()
+            trans = s1.begin()
 
-        eq_(s1.in_transaction(), True)
-        is_(s1.get_transaction(), trans)
+            eq_(s1.in_transaction(), True)
+            is_(s1.get_transaction(), trans)
 
-        sp1 = s1.begin_nested()
+            sp1 = s1.begin_nested()
 
-        eq_(s1.in_transaction(), True)
-        is_(s1.get_transaction(), trans)
-        is_(s1.get_nested_transaction(), sp1)
+            eq_(s1.in_transaction(), True)
+            is_(s1.get_transaction(), trans)
+            is_(s1.get_nested_transaction(), sp1)
 
-        sp2 = s1.begin_nested()
+            sp2 = s1.begin_nested()
 
-        eq_(s1.in_transaction(), True)
-        eq_(s1.in_nested_transaction(), True)
-        is_(s1.get_transaction(), trans)
-        is_(s1.get_nested_transaction(), sp2)
+            eq_(s1.in_transaction(), True)
+            eq_(s1.in_nested_transaction(), True)
+            is_(s1.get_transaction(), trans)
+            is_(s1.get_nested_transaction(), sp2)
 
-        sp2.rollback()
+            sp2.rollback()
 
-        eq_(s1.in_transaction(), True)
-        eq_(s1.in_nested_transaction(), True)
-        is_(s1.get_transaction(), trans)
-        is_(s1.get_nested_transaction(), sp1)
+            eq_(s1.in_transaction(), True)
+            eq_(s1.in_nested_transaction(), True)
+            is_(s1.get_transaction(), trans)
+            is_(s1.get_nested_transaction(), sp1)
 
-        sp1.rollback()
+            sp1.rollback()
 
-        is_(s1.get_nested_transaction(), None)
-        eq_(s1.in_transaction(), True)
-        eq_(s1.in_nested_transaction(), False)
-        is_(s1.get_transaction(), trans)
+            is_(s1.get_nested_transaction(), None)
+            eq_(s1.in_transaction(), True)
+            eq_(s1.in_nested_transaction(), False)
+            is_(s1.get_transaction(), trans)
 
-        s1.rollback()
+            s1.rollback()
 
-        eq_(s1.in_transaction(), False)
-        is_(s1.get_transaction(), None)
+            eq_(s1.in_transaction(), False)
+            is_(s1.get_transaction(), None)
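The `in_transaction()`, `get_transaction()` and `get_nested_transaction()` states asserted above can be modeled as a stack of transaction markers: the root transaction sits at the bottom, savepoints stack on top, and ending the innermost one pops the stack. A toy model (not SQLAlchemy's actual internals):

```python
class Transaction:
    """Hypothetical transaction marker."""
    def __init__(self, nested=False):
        self.nested = nested

class TrackingSession:
    """Minimal bookkeeping mirroring the flags tested above."""
    def __init__(self):
        self._stack = []

    def begin(self, nested=False):
        t = Transaction(nested)
        self._stack.append(t)
        return t

    def end(self):
        # commit or rollback of the innermost transaction pops it
        self._stack.pop()

    def in_transaction(self):
        return bool(self._stack)

    def get_transaction(self):
        return self._stack[0] if self._stack else None

    def get_nested_transaction(self):
        for t in reversed(self._stack):
            if t.nested:
                return t
        return None

s1 = TrackingSession()
trans = s1.begin()
sp1 = s1.begin(nested=True)
sp2 = s1.begin(nested=True)
assert s1.get_nested_transaction() is sp2
s1.end()  # sp2 ends; sp1 becomes the current savepoint
assert s1.get_nested_transaction() is sp1
s1.end()  # sp1 ends; only the root transaction remains
assert s1.get_nested_transaction() is None
assert s1.get_transaction() is trans
```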
 
 
 class NaturalPKRollbackTest(fixtures.MappedTest):
@@ -2674,8 +2702,11 @@ class NaturalPKRollbackTest(fixtures.MappedTest):
 class JoinIntoAnExternalTransactionFixture(object):
     """Test the "join into an external transaction" examples"""
 
-    def setup(self):
-        self.connection = testing.db.connect()
+    __leave_connections_for_teardown__ = True
+
+    def setup_test(self):
+        self.engine = testing.db
+        self.connection = self.engine.connect()
 
         self.metadata = MetaData()
         self.table = Table(
@@ -2686,6 +2717,17 @@ class JoinIntoAnExternalTransactionFixture(object):
 
         self.setup_session()
 
+    def teardown_test(self):
+        self.teardown_session()
+
+        with self.connection.begin():
+            self._assert_count(0)
+
+        with self.connection.begin():
+            self.table.drop(self.connection)
+
+        self.connection.close()
+
     def test_something(self):
         A = self.A
 
@@ -2727,18 +2769,6 @@ class JoinIntoAnExternalTransactionFixture(object):
         )
         eq_(result, count)
 
-    def teardown(self):
-        self.teardown_session()
-
-        with self.connection.begin():
-            self._assert_count(0)
-
-        with self.connection.begin():
-            self.table.drop(self.connection)
-
-        # return connection to the Engine
-        self.connection.close()
-
 
 class NewStyleJoinIntoAnExternalTransactionTest(
     JoinIntoAnExternalTransactionFixture
@@ -2775,7 +2805,8 @@ class NewStyleJoinIntoAnExternalTransactionTest(
         # rollback - everything that happened with the
         # Session above (including calls to commit())
         # is rolled back.
-        self.trans.rollback()
+        if self.trans.is_active:
+            self.trans.rollback()
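The `is_active` guard added above avoids rolling back an external transaction that the test body may already have ended, which would raise. The shape of the guard, with a hypothetical stand-in class:

```python
class ExternalTransaction:
    """Hypothetical transaction: rollback() on an inactive one would error."""
    def __init__(self):
        self.is_active = True

    def rollback(self):
        if not self.is_active:
            raise RuntimeError("this transaction is already deassociated")
        self.is_active = False

trans = ExternalTransaction()
trans.rollback()       # first path (e.g. the test itself) ends it
if trans.is_active:    # guarded teardown call is now a safe no-op
    trans.rollback()
```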
 
 
 class FutureJoinIntoAnExternalTransactionTest(
index 84373b2dca576de475d6fb3a20df6bbdb3843522..2c35bec45f818d6b7c0ddb8793b0ef8a0d63b343 100644 (file)
@@ -203,14 +203,6 @@ class UnicodeSchemaTest(fixtures.MappedTest):
         cls.tables["t1"] = t1
         cls.tables["t2"] = t2
 
-    @classmethod
-    def setup_class(cls):
-        super(UnicodeSchemaTest, cls).setup_class()
-
-    @classmethod
-    def teardown_class(cls):
-        super(UnicodeSchemaTest, cls).teardown_class()
-
     def test_mapping(self):
         t2, t1 = self.tables.t2, self.tables.t1
 
index 4e713627c1ef1936d46068243a7b26decfb1fe16..65089f773ce1c3d8e46be04e63c485b11d91fe5c 100644 (file)
@@ -771,14 +771,13 @@ class RudimentaryFlushTest(UOWTest):
 
 
 class SingleCycleTest(UOWTest):
-    def teardown(self):
+    def teardown_test(self):
         engines.testing_reaper.rollback_all()
         # mysql can't handle delete from nodes
         # since it doesn't deal with the FKs correctly,
         # so wipe out the parent_id first
         with testing.db.begin() as conn:
             conn.execute(self.tables.nodes.update().values(parent_id=None))
-        super(SingleCycleTest, self).teardown()
 
     def test_one_to_many_save(self):
         Node, nodes = self.classes.Node, self.tables.nodes
index d5a718372762fbb0f67bfc3b1a44e004f1d64bd1..3c9b39ac71567da78d61f90641fb0aaf8fa0089f 100644 (file)
@@ -1624,15 +1624,15 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def postgresql_utf8_server_encoding(self):
+        def go(config):
+            if not against(config, "postgresql"):
+                return False
 
-        return only_if(
-            lambda config: against(config, "postgresql")
-            and config.db.connect(close_with_result=True)
-            .exec_driver_sql("show server_encoding")
-            .scalar()
-            .lower()
-            == "utf8"
-        )
+            with config.db.connect() as conn:
+                enc = conn.exec_driver_sql("show server_encoding").scalar()
+                return enc.lower() == "utf8"
+
+        return only_if(go)
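The rewritten requirement above replaces the `close_with_result` connect pattern with a `with` block, so the connection is returned to the pool whether or not the query raises. The same shape with hypothetical stand-ins for `config.db` and the result object:

```python
class DummyResult:
    """Hypothetical result proxy."""
    def scalar(self):
        return "UTF8"

class DummyConn:
    closed = False

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.closed = True  # connection always released

    def exec_driver_sql(self, sql):
        return DummyResult()

class DummyDB:
    """Hypothetical stand-in for config.db."""
    def connect(self):
        return DummyConn()

def utf8_server_encoding(db):
    # open, query, and always close the connection
    with db.connect() as conn:
        enc = conn.exec_driver_sql("show server_encoding").scalar()
        return enc.lower() == "utf8"
```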
 
     @property
     def cxoracle6_or_greater(self):
index 4bef1df7f3c917674dc01309637a0e9b83974fb0..b44971cecd24b89efe046c59b1300d13a2dabaf9 100644 (file)
@@ -26,7 +26,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         metadata = MetaData()
         global info_table
         info_table = Table(
@@ -52,7 +52,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL):
             )
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         with testing.db.begin() as conn:
             info_table.drop(conn)
 
index 70281d4e89d46f23786c956e6207eb8703194383..1ac3613f74fc0d2aefe0bda3bfe06d22c922c675 100644 (file)
@@ -1203,7 +1203,7 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase):
 
 class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         # TODO: we need to get dialects here somehow, perhaps in test_suite?
         [
             importlib.import_module("sqlalchemy.dialects.%s" % d)
index fdffe04bf37d8babbeda724941dd9ecd2ec2afb4..4429753ecb7fd45014c1b5564fe62052bd2317e3 100644 (file)
@@ -4306,7 +4306,7 @@ class StringifySpecialTest(fixtures.TestBase):
 
 class KwargPropagationTest(fixtures.TestBase):
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         from sqlalchemy.sql.expression import ColumnClause, TableClause
 
         class CatchCol(ColumnClause):
index 2a2e70bc3941551948758c49a8ce18f7d5922329..8be7eed1f51d0143774cfcfc719cf25fba1831a9 100644 (file)
@@ -503,9 +503,8 @@ class DefaultRoundTripTest(fixtures.TablesTest):
             Column("col11", MyType(), default="foo"),
         )
 
-    def teardown(self):
+    def teardown_test(self):
         self.default_generator["x"] = 50
-        super(DefaultRoundTripTest, self).teardown()
 
     def test_standalone(self, connection):
         t = self.tables.default_test
@@ -1226,7 +1225,7 @@ class SpecialTypePKTest(fixtures.TestBase):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         class MyInteger(TypeDecorator):
             impl = Integer
 
index acc12a5febb0281dd49b1b56a02a09ae9c5fea8e..7775652208b6f21c724a47a16117db8dafcb57bd 100644 (file)
@@ -2488,12 +2488,12 @@ class LegacySequenceExecTest(fixtures.TestBase):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         cls.seq = Sequence("my_sequence")
         cls.seq.create(testing.db)
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         cls.seq.drop(testing.db)
 
     def _assert_seq_result(self, ret):
@@ -2574,7 +2574,7 @@ class LegacySequenceExecTest(fixtures.TestBase):
 
 
 class DDLDeprecatedBindTest(fixtures.TestBase):
-    def teardown(self):
+    def teardown_test(self):
         with testing.db.begin() as conn:
             if inspect(conn).has_table("foo"):
                 conn.execute(schema.DropTable(table("foo")))
index 4edc9d0258306dbf9784c42826a57c40dd2da053..a6001ba9da4e6f5c01751d714077911394d93e22 100644 (file)
@@ -47,7 +47,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults):
     ability to copy and modify a ClauseElement in place."""
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global A, B
 
         # establish two fictitious ClauseElements.
@@ -321,7 +321,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global t1, t2, t3
         t1 = table("table1", column("col1"), column("col2"), column("col3"))
         t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1012,7 +1012,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global t1, t2
         t1 = table(
             "table1",
@@ -1196,7 +1196,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global t1, t2
         t1 = table("table1", column("col1"), column("col2"), column("col3"))
         t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -1943,7 +1943,7 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global table1, table2, table3, table4
 
         def _table(name):
@@ -2031,7 +2031,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global t1, t2
         t1 = table("table1", column("col1"), column("col2"), column("col3"))
         t2 = table("table2", column("col1"), column("col2"), column("col3"))
@@ -2128,7 +2128,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL):
     # fixme: consolidate coverage from elsewhere here and expand
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         global t1, t2
         t1 = table("table1", column("col1"), column("col2"), column("col3"))
         t2 = table("table2", column("col1"), column("col2"), column("col3"))
index 6afe41aaca8786fda3530a90d76bc1fc2cb3ef98..b0bcee18e242ed6fdd4b714b2f3cf4a37fc2af9b 100644 (file)
@@ -25,7 +25,7 @@ class TestFindUnmatchingFroms(fixtures.TablesTest):
         Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
         Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
 
-    def setup(self):
+    def setup_test(self):
         self.a = self.tables.table_a
         self.b = self.tables.table_b
         self.c = self.tables.table_c
@@ -267,8 +267,10 @@ class TestLinter(fixtures.TablesTest):
             with self.bind.connect() as conn:
                 conn.execute(query)
 
-    def test_no_linting(self):
-        eng = engines.testing_engine(options={"enable_from_linting": False})
+    def test_no_linting(self, metadata, connection):
+        eng = engines.testing_engine(
+            options={"enable_from_linting": False, "use_reaper": False}
+        )
         eng.pool = self.bind.pool  # needed for SQLite
         a, b = self.tables("table_a", "table_b")
         query = select(a.c.col_a).where(b.c.col_b == 5)
index 91076f9c388e4501b9655bf98d4e60448ffca7da..32ea642d744432e20671b5fe24e41883309aaff9 100644 (file)
@@ -54,10 +54,10 @@ table1 = table(
 class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def setup(self):
+    def setup_test(self):
         self._registry = deepcopy(functions._registry)
 
-    def teardown(self):
+    def teardown_test(self):
         functions._registry = self._registry
 
     def test_compile(self):
@@ -938,7 +938,7 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase):
 class ExecuteTest(fixtures.TestBase):
     __backend__ = True
 
-    def tearDown(self):
+    def teardown_test(self):
         pass
 
     def test_conn_execute(self, connection):
@@ -1113,10 +1113,10 @@ class ExecuteTest(fixtures.TestBase):
 class RegisterTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def setup(self):
+    def setup_test(self):
         self._registry = deepcopy(functions._registry)
 
-    def teardown(self):
+    def teardown_test(self):
         functions._registry = self._registry
 
     def test_GenericFunction_is_registered(self):
index 502f70ce78e66338a465b64c7a598f253150d75a..9bf351f5c41cc73c77bd957d4b732c51f3bb411c 100644 (file)
@@ -4257,7 +4257,7 @@ class DialectKWArgTest(fixtures.TestBase):
         with mock.patch("sqlalchemy.dialects.registry.load", load):
             yield
 
-    def teardown(self):
+    def teardown_test(self):
         Index._kw_registry.clear()
 
     def test_participating(self):
index aaeed68ddb81fcee97c01831123e6a43b6daf401..270e79ba16c7a04bd71c5a393a53dee6293ef98d 100644 (file)
@@ -608,7 +608,7 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
 
 class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
-    def setUp(self):
+    def setup_test(self):
         class MyTypeCompiler(compiler.GenericTypeCompiler):
             def visit_mytype(self, type_, **kw):
                 return "MYTYPE"
@@ -766,7 +766,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
 
 class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
-    def setUp(self):
+    def setup_test(self):
         class MyTypeCompiler(compiler.GenericTypeCompiler):
             def visit_mytype(self, type_, **kw):
                 return "MYTYPE"
@@ -2370,7 +2370,7 @@ class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def setUp(self):
+    def setup_test(self):
         self.table = table(
             "mytable", column("myid", Integer), column("name", String)
         )
@@ -2403,7 +2403,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default_enhanced"
 
-    def setUp(self):
+    def setup_test(self):
         self.table = table(
             "mytable", column("myid", Integer), column("name", String)
         )
index 136f10cf4c3f74918d7f9c5a9f6ddb4b0b8cf06b..7ad12c62037d8c333c792c87114b0a980f97a436 100644 (file)
@@ -661,8 +661,8 @@ class CursorResultTest(fixtures.TablesTest):
         assert_raises(KeyError, lambda: row._mapping["Case_insensitive"])
         assert_raises(KeyError, lambda: row._mapping["casesensitive"])
 
-    def test_row_case_sensitive_unoptimized(self):
-        with engines.testing_engine().connect() as ins_conn:
+    def test_row_case_sensitive_unoptimized(self, testing_engine):
+        with testing_engine().connect() as ins_conn:
             row = ins_conn.execute(
                 select(
                     literal_column("1").label("case_insensitive"),
@@ -1234,8 +1234,7 @@ class CursorResultTest(fixtures.TablesTest):
         eq_(proxy[0], "value")
         eq_(proxy._mapping["key"], "value")
 
-    @testing.provide_metadata
-    def test_no_rowcount_on_selects_inserts(self):
+    def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine):
         """assert that rowcount is only called on deletes and updates.
 
        This is because cursor.rowcount can be expensive on some dialects
@@ -1244,9 +1243,7 @@ class CursorResultTest(fixtures.TablesTest):
 
         """
 
-        metadata = self.metadata
-
-        engine = engines.testing_engine()
+        engine = testing_engine()
 
         t = Table("t1", metadata, Column("data", String(10)))
         metadata.create_all(engine)
@@ -2132,7 +2129,9 @@ class AlternateCursorResultTest(fixtures.TablesTest):
 
     @classmethod
     def setup_bind(cls):
-        cls.engine = engine = engines.testing_engine("sqlite://")
+        cls.engine = engine = engines.testing_engine(
+            "sqlite://", options={"scope": "class"}
+        )
         return engine
 
     @classmethod
index 65325aa6f70194aab41c6eaa9979ec67a28396da..5cfc2663f603fdb9852fdecf09492fb1f47e2e31 100644 (file)
@@ -100,12 +100,12 @@ class SequenceExecTest(fixtures.TestBase):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
+    def setup_test_class(cls):
         cls.seq = Sequence("my_sequence")
         cls.seq.create(testing.db)
 
     @classmethod
-    def teardown_class(cls):
+    def teardown_test_class(cls):
         cls.seq.drop(testing.db)
 
     def _assert_seq_result(self, ret):
index 0e114780074d503117615233d0f8d42ac2bdc16c..64ace87dfab066a5451eee94368bd94b39db90db 100644 (file)
@@ -1375,7 +1375,7 @@ class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
 
 
 class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
-    def setup(self):
+    def setup_test(self):
         class UTypeOne(types.UserDefinedType):
             def get_col_spec(self):
                 return "UTYPEONE"
@@ -2504,7 +2504,7 @@ class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
 
 
 class JSONTest(fixtures.TestBase):
-    def setup(self):
+    def setup_test(self):
         metadata = MetaData()
         self.test_table = Table(
             "test_table",
@@ -3445,7 +3445,12 @@ class BooleanTest(
     @testing.requires.non_native_boolean_unconstrained
     def test_constraint(self, connection):
         assert_raises(
-            (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError),
+            (
+                exc.IntegrityError,
+                exc.ProgrammingError,
+                exc.OperationalError,
+                exc.InternalError,  # older pymysql's do this
+            ),
             connection.exec_driver_sql,
             "insert into boolean_table (id, value) values(1, 5)",
         )
diff --git a/tox.ini b/tox.ini
index 8f2fd9de0fe00b7e2335fcf721e3542d93f700da..ea2b76e1665268bd594c2da4422971655d7b3298 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -15,7 +15,9 @@ install_command=python -m pip install {env:TOX_PIP_OPTS:} {opts} {packages}
 usedevelop=
      cov: True
 
-deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only
+deps=
+     pytest>=4.6.11,<5.0; python_version < '3'
+     pytest>=6.2; python_version >= '3'
      pytest-xdist
      greenlet != 0.4.17
      mock; python_version < '3.3'
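The `deps` change above splits the single pytest pin into two lines gated by PEP 508 environment markers, so Python 2 keeps pytest 4.x while Python 3 moves to pytest 6.2+. A minimal sketch (a hand-rolled check, not the real packaging-library marker evaluator) of how such a marker selects a requirement:

```python
# Hypothetical, simplified marker check mirroring the tox deps lines:
#   pytest>=4.6.11,<5.0; python_version < '3'
#   pytest>=6.2;         python_version >= '3'
def select_pytest_pin(python_version):
    """Pick the pytest requirement line for a given X.Y version string."""
    major = int(python_version.split(".")[0])
    if major < 3:
        return "pytest>=4.6.11,<5.0"
    return "pytest>=6.2"

print(select_pytest_pin("2.7"))  # pytest>=4.6.11,<5.0
print(select_pytest_pin("3.9"))  # pytest>=6.2
```

In the real tox run, pip evaluates these markers against the target interpreter, so each environment installs exactly one of the two pins.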
@@ -74,9 +76,11 @@ setenv=
     sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file}
 
     postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql}
+    py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}}
     py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000}
 
     mysql: MYSQL={env:TOX_MYSQL:--db mysql}
+    py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}}
     mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql}
     py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql}
 
@@ -89,7 +93,7 @@ setenv=
 # tox as of 2.0 blocks all environment variables from the
 # outside, unless they are here (or in TOX_TESTENV_PASSENV,
 # wildcards OK).  Need at least these
-passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
+passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_POSTGRESQL_PY2K TOX_MYSQL TOX_MYSQL_PY2K TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS
 
 # for nocext, we rm *.so in lib in case we are doing usedevelop=True
 commands=