From f1e96cb0874927a475d0c111393b7861796dd758 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 10 Jan 2021 13:44:14 -0500 Subject: [PATCH] reinvent xdist hooks in terms of pytest fixtures To allow the "connection" pytest fixture and others work correctly in conjunction with setup/teardown that expects to be external to the transaction, remove and prevent any usage of "xdist" style names that are hardcoded by pytest to run inside of fixtures, even function level ones. Instead use pytest autouse fixtures to implement our own r"setup|teardown_test(?:_class)?" methods so that we can ensure function-scoped fixtures are run within them. A new more explicit flow is set up within plugin_base and pytestplugin such that the order of setup/teardown steps, which there are now many, is fully documented and controllable. New granularity has been added to the test teardown phase to distinguish between "end of the test" when lock-holding structures on connections should be released to allow for table drops, vs. "end of the test plus its teardown steps" when we can perform final cleanup on connections and run assertions that everything is closed out. From there we can remove most of the defensive "tear down everything" logic inside of engines which for many years would frequently dispose of pools over and over again, creating for a broken and expensive connection flow. A quick test shows that running test/sql/ against a single Postgresql engine with the new approach uses 75% fewer new connections, creating 42 new connections total, vs. 164 new connections total with the previous system. As part of this, the new fixtures metadata/connection/future_connection have been integrated such that they can be combined together effectively. The fixture_session(), provide_metadata() fixtures have been improved, including that fixture_session() now strongly references sessions which are explicitly torn down before table drops occur afer a test. Major changes have been made to the ConnectionKiller such that it now features different "scopes" for testing engines and will limit its cleanup to those testing engines corresponding to end of test, end of test class, or end of test session. The system by which it tracks DBAPI connections has been reworked, is ultimately somewhat similar to how it worked before but is organized more clearly along with the proxy-tracking logic. A "testing_engine" fixture is also added that works as a pytest fixture rather than a standalone function. The connection cleanup logic should now be very robust, as we now can use the same global connection pools for the whole suite without ever disposing them, while also running a query for PostgreSQL locks remaining after every test and assert there are no open transactions leaking between tests at all. Additional steps are added that also accommodate for asyncio connections not explicitly closed, as is the case for legacy sync-style tests as well as the async tests themselves. As always, hundreds of tests are further refined to use the new fixtures where problems with loose connections were identified, largely as a result of the new PostgreSQL assertions, many more tests have moved from legacy patterns into the newest. An unfortunate discovery during the creation of this system is that autouse fixtures (as well as if they are set up by @pytest.mark.usefixtures) are not usable at our current scale with pytest 4.6.11 running under Python 2. It's unclear if this is due to the older version of pytest or how it implements itself for Python 2, as well as if the issue is CPU slowness or just large memory use, but collecting the full span of tests takes over a minute for a single process when any autouse fixtures are in place and on CI the jobs just time out after ten minutes. So at the moment this patch also reinvents a small version of "autouse" fixtures when py2k is running, which skips generating the real fixture and instead uses two global pytest fixtures (which don't seem to impact performance) to invoke the "autouse" fixtures ourselves outside of pytest. This will limit our ability to do more with fixtures until we can remove py2k support. py.test is still observed to be much slower in collection in the 4.6.11 version compared to modern 6.2 versions, so add support for new TOX_POSTGRESQL_PY2K and TOX_MYSQL_PY2K environment variables that will run the suite for fewer backends under Python 2. For Python 3 pin pytest to modern 6.2 versions where performance for collection has been improved greatly. Includes the following improvements: Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would be raised rather than :class:`.exc.TimeoutError`. Also repaired the :paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using the async engine, which previously would ignore the timeout and block rather than timing out immediately as is the behavior with regular :class:`.QueuePool`. For asyncio the connection pool will now also not interact at all with an asyncio connection whose ConnectionFairy is being garbage collected; a warning that the connection was not properly closed is emitted and the connection is discarded. Within the test suite the ConnectionKiller is now maintaining strong references to all DBAPI connections and ensuring they are released when tests end, including those whose ConnectionFairy proxies are GCed. Identified cx_Oracle.stmtcachesize as a major factor in Oracle test scalability issues, this can be reset on a per-test basis rather than setting it to zero across the board. the addition of this flag has resolved the long-standing oracle "two task" error problem. For SQL Server, changed the temp table style used by the "suite" tests to be the double-pound-sign, i.e. global, variety, which is much easier to test generically. There are already reflection tests that are more finely tuned to both styles of temp table within the mssql test suite. Additionally, added an extra step to the "dropfirst" mechanism for SQL Server that will remove all foreign key constraints first as some issues were observed when using this flag when multiple schemas had not been torn down. Identified and fixed two subtle failure modes in the engine, when commit/rollback fails in a begin() context manager, the connection is explicitly closed, and when "initialize()" fails on the first new connection of a dialect, the transactional state on that connection is still rolled back. Fixes: #5826 Fixes: #5827 Change-Id: Ib1d05cb8c7cf84f9a4bfd23df397dc23c9329bfe --- doc/build/changelog/unreleased_14/5823.rst | 13 + doc/build/changelog/unreleased_14/5827.rst | 10 + lib/sqlalchemy/dialects/mssql/base.py | 13 +- lib/sqlalchemy/dialects/mssql/provision.py | 34 +- lib/sqlalchemy/dialects/oracle/cx_oracle.py | 1 + lib/sqlalchemy/dialects/oracle/provision.py | 42 +- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 1 - .../dialects/postgresql/provision.py | 21 + lib/sqlalchemy/dialects/sqlite/provision.py | 6 +- lib/sqlalchemy/engine/base.py | 18 +- lib/sqlalchemy/engine/create.py | 9 +- lib/sqlalchemy/future/engine.py | 15 +- lib/sqlalchemy/pool/base.py | 35 +- lib/sqlalchemy/testing/__init__.py | 2 + lib/sqlalchemy/testing/assertions.py | 8 + lib/sqlalchemy/testing/config.py | 4 + lib/sqlalchemy/testing/engines.py | 179 +++-- lib/sqlalchemy/testing/fixtures.py | 245 ++++--- lib/sqlalchemy/testing/plugin/bootstrap.py | 5 + lib/sqlalchemy/testing/plugin/plugin_base.py | 39 +- lib/sqlalchemy/testing/plugin/pytestplugin.py | 188 ++++- .../testing/plugin/reinvent_fixtures_py2k.py | 112 +++ lib/sqlalchemy/testing/provision.py | 11 +- .../testing/suite/test_reflection.py | 2 +- lib/sqlalchemy/testing/suite/test_results.py | 32 +- lib/sqlalchemy/testing/suite/test_types.py | 74 +- lib/sqlalchemy/testing/util.py | 60 +- lib/sqlalchemy/util/queue.py | 15 +- test/aaa_profiling/test_compiler.py | 2 +- test/aaa_profiling/test_memusage.py | 4 +- test/aaa_profiling/test_misc.py | 2 +- test/aaa_profiling/test_orm.py | 6 +- test/aaa_profiling/test_pool.py | 2 +- test/base/test_events.py | 22 +- test/base/test_inspect.py | 2 +- test/base/test_tutorials.py | 4 +- test/dialect/mssql/test_compiler.py | 2 +- test/dialect/mssql/test_deprecations.py | 2 +- test/dialect/mssql/test_query.py | 3 +- test/dialect/mysql/test_compiler.py | 4 +- test/dialect/mysql/test_reflection.py | 2 +- test/dialect/oracle/test_compiler.py | 2 +- test/dialect/oracle/test_dialect.py | 4 +- test/dialect/oracle/test_reflection.py | 16 +- test/dialect/oracle/test_types.py | 4 +- test/dialect/postgresql/test_async_pg_py3k.py | 4 +- test/dialect/postgresql/test_compiler.py | 8 +- test/dialect/postgresql/test_dialect.py | 21 +- test/dialect/postgresql/test_query.py | 6 +- test/dialect/postgresql/test_reflection.py | 499 ++++++------- test/dialect/postgresql/test_types.py | 692 ++++++++---------- test/dialect/test_sqlite.py | 32 +- test/engine/test_ddlevents.py | 4 +- test/engine/test_deprecations.py | 16 +- test/engine/test_execute.py | 364 ++++----- test/engine/test_logging.py | 16 +- test/engine/test_pool.py | 57 +- test/engine/test_processors.py | 10 +- test/engine/test_reconnect.py | 20 +- test/engine/test_reflection.py | 9 +- test/engine/test_transaction.py | 21 +- test/ext/asyncio/test_engine_py3k.py | 16 +- test/ext/declarative/test_inheritance.py | 4 +- test/ext/declarative/test_reflection.py | 127 ++-- test/ext/test_associationproxy.py | 16 +- test/ext/test_baked.py | 2 +- test/ext/test_compiler.py | 2 +- test/ext/test_extendedattr.py | 6 +- test/ext/test_horizontal_shard.py | 32 +- test/ext/test_hybrid.py | 2 +- test/ext/test_mutable.py | 15 +- test/ext/test_orderinglist.py | 16 +- test/orm/declarative/test_basic.py | 4 +- test/orm/declarative/test_concurrency.py | 2 +- test/orm/declarative/test_inheritance.py | 4 +- test/orm/declarative/test_mixin.py | 4 +- test/orm/declarative/test_reflection.py | 5 +- test/orm/inheritance/test_basic.py | 49 +- test/orm/test_attributes.py | 8 +- test/orm/test_bind.py | 74 +- test/orm/test_collection.py | 5 +- test/orm/test_compile.py | 2 +- test/orm/test_cycles.py | 3 +- test/orm/test_deprecations.py | 80 +- test/orm/test_eager_relations.py | 14 +- test/orm/test_events.py | 7 +- test/orm/test_froms.py | 106 +-- test/orm/test_lazy_relations.py | 57 +- test/orm/test_load_on_fks.py | 30 +- test/orm/test_mapper.py | 6 +- test/orm/test_options.py | 4 +- test/orm/test_query.py | 37 +- test/orm/test_rel_fn.py | 2 +- test/orm/test_relationships.py | 6 +- test/orm/test_selectin_relations.py | 50 +- test/orm/test_session.py | 20 +- test/orm/test_subquery_relations.py | 50 +- test/orm/test_transaction.py | 681 +++++++++-------- test/orm/test_unitofwork.py | 8 - test/orm/test_unitofworkv2.py | 3 +- test/requirements.py | 16 +- test/sql/test_case_statement.py | 4 +- test/sql/test_compare.py | 2 +- test/sql/test_compiler.py | 2 +- test/sql/test_defaults.py | 5 +- test/sql/test_deprecations.py | 6 +- test/sql/test_external_traversal.py | 14 +- test/sql/test_from_linter.py | 8 +- test/sql/test_functions.py | 10 +- test/sql/test_metadata.py | 2 +- test/sql/test_operators.py | 8 +- test/sql/test_resultset.py | 15 +- test/sql/test_sequences.py | 4 +- test/sql/test_types.py | 11 +- tox.ini | 8 +- 115 files changed, 2638 insertions(+), 2092 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/5823.rst create mode 100644 doc/build/changelog/unreleased_14/5827.rst create mode 100644 lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py diff --git a/doc/build/changelog/unreleased_14/5823.rst b/doc/build/changelog/unreleased_14/5823.rst new file mode 100644 index 0000000000..74debdaa93 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5823.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, pool, asyncio + :tickets: 5823 + + When using an asyncio engine, the connection pool will now detach and + discard a pooled connection that is was not explicitly closed/returned to + the pool when its tracking object is garbage collected, emitting a warning + that the connection was not properly closed. As this operation occurs + during Python gc finalizers, it's not safe to run any IO operations upon + the connection including transaction rollback or connection close as this + will often be outside of the event loop. + + diff --git a/doc/build/changelog/unreleased_14/5827.rst b/doc/build/changelog/unreleased_14/5827.rst new file mode 100644 index 0000000000..d5c8acd8c5 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5827.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, asyncio + :tickets: 5827 + + Fixed bug in asyncio connection pool where ``asyncio.TimeoutError`` would + be raised rather than :class:`.exc.TimeoutError`. Also repaired the + :paramref:`_sa.create_engine.pool_timeout` parameter set to zero when using + the async engine, which previously would ignore the timeout and block + rather than timing out immediately as is the behavior with regular + :class:`.QueuePool`. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 538679fcf4..0227e515d3 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2785,15 +2785,14 @@ class MSDialect(default.DefaultDialect): def has_table(self, connection, tablename, dbname, owner, schema): if tablename.startswith("#"): # temporary table tables = ischema.mssql_temp_table_columns - result = connection.execute( - sql.select(tables.c.table_name) - .where( - tables.c.table_name.like( - self._temp_table_name_like_pattern(tablename) - ) + + s = sql.select(tables.c.table_name).where( + tables.c.table_name.like( + self._temp_table_name_like_pattern(tablename) ) - .limit(1) ) + + result = connection.execute(s.limit(1)) return result.scalar() is not None else: tables = ischema.tables diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py index 269eb164f7..56f3305a70 100644 --- a/lib/sqlalchemy/dialects/mssql/provision.py +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -1,6 +1,14 @@ +from sqlalchemy import inspect +from sqlalchemy import Integer from ... import create_engine from ... import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import get_temp_table_name from ...testing.provision import log @@ -38,7 +46,6 @@ def _mssql_drop_ignore(conn, ident): # "where database_id=db_id('%s')" % ident): # log.info("killing SQL server session %s", row['session_id']) # conn.exec_driver_sql("kill %s" % row['session_id']) - conn.exec_driver_sql("drop database %s" % ident) log.info("Reaped db: %s", ident) return True @@ -83,4 +90,27 @@ def _mssql_temp_table_keyword_args(cfg, eng): @get_temp_table_name.for_db("mssql") def _mssql_get_temp_table_name(cfg, eng, base_name): - return "#" + base_name + return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + inspector = inspect(conn) + for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): + for tname in inspector.get_table_names(schema=schema): + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + for fk in inspect(conn).get_foreign_keys(tname, schema=schema): + conn.execute( + DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fk["name"] + ) + ) + ) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 042443692d..b8b4df760c 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -93,6 +93,7 @@ The parameters accepted by the cx_oracle dialect are as follows: * ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. + .. _cx_oracle_unicode: Unicode diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index d51131c0b6..e0dadd58ea 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -6,11 +6,11 @@ from ...testing.provision import create_db from ...testing.provision import drop_db from ...testing.provision import follower_url_from_main from ...testing.provision import log +from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs from ...testing.provision import set_default_schema_on_connection -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args -from ...testing.provision import update_db_opts @create_db.for_db("oracle") @@ -57,21 +57,39 @@ def _oracle_drop_db(cfg, eng, ident): _ora_drop_ignore(conn, "%s_ts2" % ident) -@update_db_opts.for_db("oracle") -def _oracle_update_db_opts(db_url, db_opts): - pass +@stop_test_class_outside_fixtures.for_db("oracle") +def stop_test_class_outside_fixtures(config, db, cls): + with db.begin() as conn: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa E501 + conn.exec_driver_sql("purge recyclebin") -@stop_test_class.for_db("oracle") -def stop_test_class(config, db, cls): - """run magic command to get rid of identity sequences + # clear statement cache on all connections that were used + # https://github.com/oracle/python-cx_Oracle/issues/519 - # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ + for cx_oracle_conn in _all_conns: + try: + sc = cx_oracle_conn.stmtcachesize + except db.dialect.dbapi.InterfaceError: + # connection closed + pass + else: + cx_oracle_conn.stmtcachesize = 0 + cx_oracle_conn.stmtcachesize = sc + _all_conns.clear() - """ - with db.begin() as conn: - conn.exec_driver_sql("purge recyclebin") +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "checkout") + def checkout(dbapi_con, con_record, con_proxy): + _all_conns.add(dbapi_con) @run_reap_dbs.for_db("oracle") diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 7c6e8fb02c..e542c77f43 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -670,7 +670,6 @@ class AsyncAdapt_asyncpg_connection: def rollback(self): if self._started: self.await_(self._transaction.rollback()) - self._transaction = None self._started = False diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index d345cdfdfe..70c3908000 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -8,6 +8,7 @@ from ...testing.provision import drop_all_schema_objects_post_tables from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import log +from ...testing.provision import prepare_for_drop_tables from ...testing.provision import set_default_schema_on_connection from ...testing.provision import temp_table_keyword_args @@ -102,3 +103,23 @@ def drop_all_schema_objects_post_tables(cfg, eng): postgresql.ENUM(name=enum["name"], schema=enum["schema"]) ) ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): + """Ensure there are no locks on the current username/database.""" + + result = connection.exec_driver_sql( + "select pid, state, wait_event_type, query " + # "select pg_terminate_backend(pid), state, wait_event_type " + "from pg_stat_activity where " + "usename=current_user " + "and datname=current_database() and state='idle in transaction' " + "and pid != pg_backend_pid()" + ) + rows = result.all() # noqa + assert not rows, ( + "PostgreSQL may not be able to DROP tables due to " + "idle in transaction: %s" + % ("; ".join(row._mapping["query"] for row in rows)) + ) diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py index f26c21e223..a481be27ef 100644 --- a/lib/sqlalchemy/dialects/sqlite/provision.py +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -7,7 +7,7 @@ from ...testing.provision import follower_url_from_main from ...testing.provision import log from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs -from ...testing.provision import stop_test_class +from ...testing.provision import stop_test_class_outside_fixtures from ...testing.provision import temp_table_keyword_args @@ -57,8 +57,8 @@ def _sqlite_drop_db(cfg, eng, ident): os.remove(path) -@stop_test_class.for_db("sqlite") -def stop_test_class(config, db, cls): +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls): with db.connect() as conn: files = [ row.file diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 50f00c025d..72d66b7c82 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -2729,14 +2729,16 @@ class Engine(Connectable, log.Identified): return self.conn def __exit__(self, type_, value, traceback): - - if type_ is not None: - self.transaction.rollback() - else: - if self.transaction.is_active: - self.transaction.commit() - if not self.close_with_result: - self.conn.close() + try: + if type_ is not None: + if self.transaction.is_active: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + finally: + if not self.close_with_result: + self.conn.close() def begin(self, close_with_result=False): """Return a context manager delivering a :class:`_engine.Connection` diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index f89be1809f..72d232085e 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -655,9 +655,12 @@ def create_engine(url, **kwargs): c = base.Connection( engine, connection=dbapi_connection, _has_events=False ) - c._execution_options = util.immutabledict() - dialect.initialize(c) - dialect.do_rollback(c.connection) + c._execution_options = util.EMPTY_DICT + + try: + dialect.initialize(c) + finally: + dialect.do_rollback(c.connection) # previously, the "first_connect" event was used here, which was then # scaled back if the "on_connect" handler were present. now, diff --git a/lib/sqlalchemy/future/engine.py b/lib/sqlalchemy/future/engine.py index d2f609326a..bfdcdfc7f8 100644 --- a/lib/sqlalchemy/future/engine.py +++ b/lib/sqlalchemy/future/engine.py @@ -368,12 +368,15 @@ class Engine(_LegacyEngine): return self.conn def __exit__(self, type_, value, traceback): - if type_ is not None: - self.transaction.rollback() - else: - if self.transaction.is_active: - self.transaction.commit() - self.conn.close() + try: + if type_ is not None: + if self.transaction.is_active: + self.transaction.rollback() + else: + if self.transaction.is_active: + self.transaction.commit() + finally: + self.conn.close() def begin(self): """Return a :class:`_future.Connection` object with a transaction diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 7c9509e452..6c3aad037f 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -426,6 +426,7 @@ class _ConnectionRecord(object): rec._checkin_failed(err) echo = pool._should_log_debug() fairy = _ConnectionFairy(dbapi_connection, rec, echo) + rec.fairy_ref = weakref.ref( fairy, lambda ref: _finalize_fairy @@ -609,6 +610,15 @@ def _finalize_fairy( assert connection is None connection = connection_record.connection + dont_restore_gced = pool._is_asyncio + + if dont_restore_gced: + detach = not connection_record or ref + can_manipulate_connection = not ref + else: + detach = not connection_record + can_manipulate_connection = True + if connection is not None: if connection_record and echo: pool.logger.debug( @@ -620,13 +630,26 @@ def _finalize_fairy( connection, connection_record, echo ) assert fairy.connection is connection - fairy._reset(pool) + if can_manipulate_connection: + fairy._reset(pool) + + if detach: + if connection_record: + fairy._pool = pool + fairy.detach() + + if can_manipulate_connection: + if pool.dispatch.close_detached: + pool.dispatch.close_detached(connection) + + pool._close_connection(connection) + else: + util.warn( + "asyncio connection is being garbage " + "collected without being properly closed: %r" + % connection + ) - # Immediately close detached instances - if not connection_record: - if pool.dispatch.close_detached: - pool.dispatch.close_detached(connection) - pool._close_connection(connection) except BaseException as e: pool.logger.error( "Exception during reset or similar", exc_info=True diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index 191252bfbb..9f2d0b857c 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -29,8 +29,10 @@ from .assertions import in_ # noqa from .assertions import is_ # noqa from .assertions import is_false # noqa from .assertions import is_instance_of # noqa +from .assertions import is_none # noqa from .assertions import is_not # noqa from .assertions import is_not_ # noqa +from .assertions import is_not_none # noqa from .assertions import is_true # noqa from .assertions import le_ # noqa from .assertions import ne_ # noqa diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 0a2aed9d85..db530a961b 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -232,6 +232,14 @@ def is_false(a, msg=None): is_(bool(a), False, msg=msg) +def is_none(a, msg=None): + is_(a, None, msg=msg) + + +def is_not_none(a, msg=None): + is_not(a, None, msg=msg) + + def is_(a, b, msg=None): """Assert a is b, with repr messaging on failure.""" assert a is b, msg or "%r is not %r" % (a, b) diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index f64153f338..750671f9f7 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -97,6 +97,10 @@ def get_current_test_name(): return _fixture_functions.get_current_test_name() +def mark_base_test_class(): + return _fixture_functions.mark_base_test_class() + + class Config(object): def __init__(self, db, db_opts, options, file_config): self._set_name(db) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index a4c1f3973b..8b334fde20 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -7,6 +7,7 @@ from __future__ import absolute_import +import collections import re import warnings import weakref @@ -20,26 +21,29 @@ from .. import pool class ConnectionKiller(object): def __init__(self): self.proxy_refs = weakref.WeakKeyDictionary() - self.testing_engines = weakref.WeakKeyDictionary() - self.conns = set() + self.testing_engines = collections.defaultdict(set) + self.dbapi_connections = set() def add_pool(self, pool): - event.listen(pool, "connect", self.connect) - event.listen(pool, "checkout", self.checkout) - event.listen(pool, "invalidate", self.invalidate) - - def add_engine(self, engine): - self.add_pool(engine.pool) - self.testing_engines[engine] = True + event.listen(pool, "checkout", self._add_conn) + event.listen(pool, "checkin", self._remove_conn) + event.listen(pool, "close", self._remove_conn) + event.listen(pool, "close_detached", self._remove_conn) + # note we are keeping "invalidated" here, as those are still + # opened connections we would like to roll back + + def _add_conn(self, dbapi_con, con_record, con_proxy): + self.dbapi_connections.add(dbapi_con) + self.proxy_refs[con_proxy] = True - def connect(self, dbapi_conn, con_record): - self.conns.add((dbapi_conn, con_record)) + def _remove_conn(self, dbapi_conn, *arg): + self.dbapi_connections.discard(dbapi_conn) - def checkout(self, dbapi_con, con_record, con_proxy): - self.proxy_refs[con_proxy] = True + def add_engine(self, engine, scope): + self.add_pool(engine.pool) - def invalidate(self, dbapi_con, con_record, exception): - self.conns.discard((dbapi_con, con_record)) + assert scope in ("class", "global", "function", "fixture") + self.testing_engines[scope].add(engine) def _safe(self, fn): try: @@ -54,53 +58,76 @@ class ConnectionKiller(object): if rec is not None and rec.is_valid: self._safe(rec.rollback) - def close_all(self): + def checkin_all(self): + # run pool.checkin() for all ConnectionFairy instances we have + # tracked. + for rec in list(self.proxy_refs): if rec is not None and rec.is_valid: - self._safe(rec._close) - - def _after_test_ctx(self): - # this can cause a deadlock with pg8000 - pg8000 acquires - # prepared statement lock inside of rollback() - if async gc - # is collecting in finalize_fairy, deadlock. - # not sure if this should be for non-cpython only. - # note that firebird/fdb definitely needs this though - for conn, rec in list(self.conns): - if rec.connection is None: - # this is a hint that the connection is closed, which - # is causing segfaults on mysqlclient due to - # https://github.com/PyMySQL/mysqlclient-python/issues/270; - # try to work around here - continue - self._safe(conn.rollback) - - def _stop_test_ctx(self): - if config.options.low_connections: - self._stop_test_ctx_minimal() - else: - self._stop_test_ctx_aggressive() + self.dbapi_connections.discard(rec.connection) + self._safe(rec._checkin) - def _stop_test_ctx_minimal(self): - self.close_all() + # for fairy refs that were GCed and could not close the connection, + # such as asyncio, roll back those remaining connections + for con in self.dbapi_connections: + self._safe(con.rollback) + self.dbapi_connections.clear() - self.conns = set() + def close_all(self): + self.checkin_all() - for rec in list(self.testing_engines): - if rec is not config.db: - rec.dispose() + def prepare_for_drop_tables(self, connection): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return - def _stop_test_ctx_aggressive(self): - self.close_all() - for conn, rec in list(self.conns): - self._safe(conn.close) - rec.connection = None + from . import provision + + provision.prepare_for_drop_tables(connection.engine.url, connection) + + def _drop_testing_engines(self, scope): + eng = self.testing_engines[scope] + for rec in list(eng): + for proxy_ref in list(self.proxy_refs): + if proxy_ref is not None and proxy_ref.is_valid: + if ( + proxy_ref._pool is not None + and proxy_ref._pool is rec.pool + ): + self._safe(proxy_ref._checkin) + rec.dispose() + eng.clear() + + def after_test(self): + self._drop_testing_engines("function") + + def after_test_outside_fixtures(self, test): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + if test.__class__.__leave_connections_for_teardown__: + return - self.conns = set() - for rec in list(self.testing_engines): - if hasattr(rec, "sync_engine"): - rec.sync_engine.dispose() - else: - rec.dispose() + self.checkin_all() + + # on PostgreSQL, this will test for any "idle in transaction" + # connections. useful to identify tests with unusual patterns + # that can't be cleaned up correctly. + from . import provision + + with config.db.connect() as conn: + provision.prepare_for_drop_tables(conn.engine.url, conn) + + def stop_test_class_inside_fixtures(self): + self.checkin_all() + self._drop_testing_engines("function") + self._drop_testing_engines("class") + + def final_cleanup(self): + self.checkin_all() + for scope in self.testing_engines: + self._drop_testing_engines(scope) def assert_all_closed(self): for rec in self.proxy_refs: @@ -111,20 +138,6 @@ class ConnectionKiller(object): testing_reaper = ConnectionKiller() -def drop_all_tables(metadata, bind): - testing_reaper.close_all() - if hasattr(bind, "close"): - bind.close() - - if not config.db.dialect.supports_alter: - from . import assertions - - with assertions.expect_warnings("Can't sort tables", assert_=False): - metadata.drop_all(bind) - else: - metadata.drop_all(bind) - - @decorator def assert_conns_closed(fn, *args, **kw): try: @@ -147,7 +160,7 @@ def rollback_open_connections(fn, *args, **kw): def close_first(fn, *args, **kw): """Decorator that closes all connections before fn execution.""" - testing_reaper.close_all() + testing_reaper.checkin_all() fn(*args, **kw) @@ -157,7 +170,7 @@ def close_open_connections(fn, *args, **kw): try: fn(*args, **kw) finally: - testing_reaper.close_all() + testing_reaper.checkin_all() def all_dialects(exclude=None): @@ -239,12 +252,14 @@ def reconnecting_engine(url=None, options=None): return engine -def testing_engine(url=None, options=None, future=False, asyncio=False): +def testing_engine(url=None, options=None, future=None, asyncio=False): """Produce an engine configured by --options with optional overrides.""" if asyncio: from sqlalchemy.ext.asyncio import create_async_engine as create_engine - elif future or config.db and config.db._is_future: + elif future or ( + config.db and config.db._is_future and future is not False + ): from sqlalchemy.future import create_engine else: from sqlalchemy import create_engine @@ -252,8 +267,10 @@ def testing_engine(url=None, options=None, future=False, asyncio=False): if not options: use_reaper = True + scope = "function" else: use_reaper = options.pop("use_reaper", True) + scope = options.pop("scope", "function") url = url or config.db.url @@ -268,16 +285,20 @@ def testing_engine(url=None, options=None, future=False, asyncio=False): default_opt.update(options) engine = create_engine(url, **options) - if asyncio: - engine.sync_engine._has_events = True - else: - engine._has_events = True # enable event blocks, helps with profiling + + if scope == "global": + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = ( + True # enable event blocks, helps with profiling + ) if isinstance(engine.pool, pool.QueuePool): engine.pool._timeout = 0 - engine.pool._max_overflow = 5 + engine.pool._max_overflow = 0 if use_reaper: - testing_reaper.add_engine(engine) + testing_reaper.add_engine(engine, scope) return engine diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index ac4d3d8fa0..f19b4652ad 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -5,6 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +import contextlib import re import sys @@ -12,12 +13,11 @@ import sqlalchemy as sa from . import assertions from . import config from . import schema -from .engines import drop_all_tables -from .engines import testing_engine from .entities import BasicEntity from .entities import ComparableEntity from .entities import ComparableMixin # noqa from .util import adict +from .util import drop_all_tables_from_metadata from .. import event from .. import util from ..orm import declarative_base @@ -25,10 +25,8 @@ from ..orm import registry from ..orm.decl_api import DeclarativeMeta from ..schema import sort_tables_and_constraints -# whether or not we use unittest changes things dramatically, -# as far as how pytest collection works. - +@config.mark_base_test_class() class TestBase(object): # A sequence of database names to always run, regardless of the # constraints below. @@ -48,81 +46,114 @@ class TestBase(object): # skipped. __skip_if__ = None + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + def assert_(self, val, msg=None): assert val, msg - # apparently a handful of tests are doing this....OK - def setup(self): - if hasattr(self, "setUp"): - self.setUp() - - def teardown(self): - if hasattr(self, "tearDown"): - self.tearDown() - @config.fixture() def connection(self): - eng = getattr(self, "bind", config.db) + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db conn = eng.connect() trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() @config.fixture() - def future_connection(self): + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection - eng = testing_engine(future=True) - conn = eng.connect() - trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() + @config.fixture() + def future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield + + @config.fixture() + def testing_engine(self): + from . import engines + + def gen_testing_engine( + url=None, options=None, future=False, asyncio=False + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, options=options, future=future, asyncio=asyncio + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") @config.fixture() - def metadata(self): + def metadata(self, request): """Provide bound MetaData for a single test, dropping afterwards.""" - from . import engines from ..sql import schema metadata = schema.MetaData() - try: - yield metadata - finally: - engines.drop_all_tables(metadata, config.db) + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) -class FutureEngineMixin(object): - @classmethod - def setup_class(cls): - from ..future.engine import Engine - from sqlalchemy import testing +_connection_fixture_connection = None - facade = Engine._future_facade(config.db) - config._current.push_engine(facade, testing) - super_ = super(FutureEngineMixin, cls) - if hasattr(super_, "setup_class"): - super_.setup_class() +@contextlib.contextmanager +def _push_future_engine(engine): - @classmethod - def teardown_class(cls): - super_ = super(FutureEngineMixin, cls) - if hasattr(super_, "teardown_class"): - super_.teardown_class() + from ..future.engine import Engine + from sqlalchemy import testing + + facade = Engine._future_facade(engine) + config._current.push_engine(facade, testing) + + yield facade - from sqlalchemy import testing + config._current.pop(testing) - config._current.pop(testing) + +class FutureEngineMixin(object): + @config.fixture(autouse=True, scope="class") + def _push_future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield class TablesTest(TestBase): @@ -151,18 +182,32 @@ class TablesTest(TestBase): other = None sequences = None - @property - def tables_test_metadata(self): - return self._tables_metadata - - @classmethod - def setup_class(cls): + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ cls._init_class() cls._setup_once_tables() cls._setup_once_inserts() + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + @classmethod def _init_class(cls): if cls.run_define_tables == "each": @@ -213,10 +258,10 @@ class TablesTest(TestBase): if self.run_define_tables == "each": self.tables.clear() if self.run_create_tables == "each": - drop_all_tables(self._tables_metadata, self.bind) + drop_all_tables_from_metadata(self._tables_metadata, self.bind) self._tables_metadata.clear() elif self.run_create_tables == "each": - drop_all_tables(self._tables_metadata, self.bind) + drop_all_tables_from_metadata(self._tables_metadata, self.bind) # no need to run deletes if tables are recreated on setup if ( @@ -242,17 +287,10 @@ class TablesTest(TestBase): file=sys.stderr, ) - def setup(self): - self._setup_each_tables() - self._setup_each_inserts() - - def teardown(self): - self._teardown_each_tables() - @classmethod def _teardown_once_metadata_bind(cls): if cls.run_create_tables: - drop_all_tables(cls._tables_metadata, cls.bind) + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) if cls.run_dispose_bind == "once": cls.dispose_bind(cls.bind) @@ -262,10 +300,6 @@ class TablesTest(TestBase): if cls.run_setup_bind is not None: cls.bind = None - @classmethod - def teardown_class(cls): - cls._teardown_once_metadata_bind() - @classmethod def setup_bind(cls): return config.db @@ -332,38 +366,47 @@ class RemovesEvents(object): self._event_fns.add((target, name, fn)) event.listen(target, name, fn, **kw) - def teardown(self): + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield for key in self._event_fns: event.remove(*key) - super_ = super(RemovesEvents, self) - if hasattr(super_, "teardown"): - super_.teardown() - - -class _ORMTest(object): - @classmethod - def teardown_class(cls): - sa.orm.session.close_all_sessions() - sa.orm.clear_mappers() -def create_session(**kw): - kw.setdefault("autoflush", False) - kw.setdefault("expire_on_commit", False) - return sa.orm.Session(config.db, **kw) +_fixture_sessions = set() def fixture_session(**kw): kw.setdefault("autoflush", True) kw.setdefault("expire_on_commit", True) - return sa.orm.Session(config.db, **kw) + sess = sa.orm.Session(config.db, **kw) + _fixture_sessions.add(sess) + return sess + + +def _close_all_sessions(): + # will close all still-referenced sessions + sa.orm.session.close_all_sessions() + _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): + _close_all_sessions() + sa.orm.clear_mappers() -class ORMTest(_ORMTest, TestBase): +def after_test(): + + if _fixture_sessions: + + _close_all_sessions() + + +class ORMTest(TestBase): pass -class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): +class MappedTest(TablesTest, assertions.AssertsExecutionResults): # 'once', 'each', None run_setup_classes = "once" @@ -372,8 +415,9 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): classes = None - @classmethod - def setup_class(cls): + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ cls._init_class() if cls.classes is None: @@ -384,18 +428,20 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): cls._setup_once_mappers() cls._setup_once_inserts() - @classmethod - def teardown_class(cls): + yield + cls._teardown_once_class() cls._teardown_once_metadata_bind() - def setup(self): + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): self._setup_each_tables() self._setup_each_classes() self._setup_each_mappers() self._setup_each_inserts() - def teardown(self): + yield + sa.orm.session.close_all_sessions() self._teardown_each_mappers() self._teardown_each_classes() @@ -404,7 +450,6 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): @classmethod def _teardown_once_class(cls): cls.classes.clear() - _ORMTest.teardown_class() @classmethod def _setup_once_classes(cls): @@ -440,6 +485,8 @@ class MappedTest(_ORMTest, TablesTest, assertions.AssertsExecutionResults): """ cls_registry = cls.classes + assert cls_registry is not None + class FindFixture(type): def __init__(cls, classname, bases, dict_): cls_registry[classname] = cls diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py index a95c947e20..1f568dfc8f 100644 --- a/lib/sqlalchemy/testing/plugin/bootstrap.py +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -40,6 +40,11 @@ def load_file_as_module(name): if to_bootstrap == "pytest": sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True + if sys.version_info < (3, 0): + sys.modules["sqla_reinvent_fixtures"] = load_file_as_module( + "reinvent_fixtures_py2k" + ) sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") else: raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py index 3594cd276d..7851fbb3ec 100644 --- a/lib/sqlalchemy/testing/plugin/plugin_base.py +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -21,6 +21,9 @@ import logging import re import sys +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False log = logging.getLogger("sqlalchemy.testing.plugin_base") @@ -381,7 +384,7 @@ def _init_symbols(options, file_config): @post def _set_disable_asyncio(opt, file_config): - if opt.disable_asyncio: + if opt.disable_asyncio or not py3k: from sqlalchemy.testing import asyncio asyncio.ENABLE_ASYNCIO = False @@ -458,6 +461,8 @@ def _setup_requirements(argument): config.requirements = testing.requires = req_cls() + config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy + @post def _prep_testing_database(options, file_config): @@ -566,17 +571,22 @@ def generate_sub_tests(cls, module): yield cls -def start_test_class(cls): +def start_test_class_outside_fixtures(cls): _do_skips(cls) _setup_engine(cls) def stop_test_class(cls): - # from sqlalchemy import inspect - # assert not inspect(testing.db).get_table_names() + # close sessions, immediate connections, etc. + fixtures.stop_test_class_inside_fixtures(cls) + + # close outstanding connection pool connections, dispose of + # additional engines + engines.testing_reaper.stop_test_class_inside_fixtures() - provision.stop_test_class(config, config.db, cls) - engines.testing_reaper._stop_test_ctx() + +def stop_test_class_outside_fixtures(cls): + provision.stop_test_class_outside_fixtures(config, config.db, cls) try: if not options.low_connections: assertions.global_cleanup_assertions() @@ -590,14 +600,16 @@ def _restore_engine(): def final_process_cleanup(): - engines.testing_reaper._stop_test_ctx_aggressive() + engines.testing_reaper.final_cleanup() assertions.global_cleanup_assertions() _restore_engine() def _setup_engine(cls): if getattr(cls, "__engine_options__", None): - eng = engines.testing_engine(options=cls.__engine_options__) + opts = dict(cls.__engine_options__) + opts["scope"] = "class" + eng = engines.testing_engine(options=opts) config._current.push_engine(eng, testing) @@ -614,7 +626,12 @@ def before_test(test, test_module_name, test_class, test_name): def after_test(test): - engines.testing_reaper._after_test_ctx() + fixtures.after_test() + engines.testing_reaper.after_test() + + +def after_test_fixtures(test): + engines.testing_reaper.after_test_outside_fixtures(test) def _possible_configs_for_cls(cls, reasons=None, sparse=False): @@ -748,6 +765,10 @@ class FixtureFunctions(ABC): def get_current_test_name(self): raise NotImplementedError() + @abc.abstractmethod + def mark_base_test_class(self): + raise NotImplementedError() + _fixture_fn_class = None diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index 46468a07dc..4eaaecebb1 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -17,6 +17,7 @@ import sys import pytest + try: import typing except ImportError: @@ -33,6 +34,14 @@ except ImportError: has_xdist = False +py2k = sys.version_info < (3, 0) +if py2k: + try: + import sqla_reinvent_fixtures as reinvent_fixtures_py2k + except ImportError: + from . import reinvent_fixtures_py2k + + def pytest_addoption(parser): group = parser.getgroup("sqlalchemy") @@ -238,6 +247,10 @@ def pytest_collection_modifyitems(session, config, items): else: newitems.append(item) + if py2k: + for item in newitems: + reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item) + # seems like the functions attached to a test class aren't sorted already? # is that true and why's that? (when using unittest, they're sorted) items[:] = sorted( @@ -251,7 +264,6 @@ def pytest_collection_modifyitems(session, config, items): def pytest_pycollect_makeitem(collector, name, obj): - if inspect.isclass(obj) and plugin_base.want_class(name, obj): from sqlalchemy.testing import config @@ -259,7 +271,6 @@ def pytest_pycollect_makeitem(collector, name, obj): obj = _apply_maybe_async(obj) ctor = getattr(pytest.Class, "from_parent", pytest.Class) - return [ ctor(name=parametrize_cls.__name__, parent=collector) for parametrize_cls in _parametrize_cls(collector.module, obj) @@ -287,12 +298,11 @@ def _is_wrapped_coroutine_function(fn): def _apply_maybe_async(obj, recurse=True): from sqlalchemy.testing import asyncio - setup_names = {"setup", "setup_class", "teardown", "teardown_class"} for name, value in vars(obj).items(): if ( (callable(value) or isinstance(value, classmethod)) and not getattr(value, "_maybe_async_applied", False) - and (name.startswith("test_") or name in setup_names) + and (name.startswith("test_")) and not _is_wrapped_coroutine_function(value) ): is_classmethod = False @@ -317,9 +327,6 @@ def _apply_maybe_async(obj, recurse=True): return obj -_current_class = None - - def _parametrize_cls(module, cls): """implement a class-based version of pytest parametrize.""" @@ -355,63 +362,153 @@ def _parametrize_cls(module, cls): return classes +_current_class = None + + def pytest_runtest_setup(item): from sqlalchemy.testing import asyncio - # here we seem to get called only based on what we collected - # in pytest_collection_modifyitems. So to do class-based stuff - # we have to tear that out. - global _current_class - if not isinstance(item, pytest.Function): return - # ... so we're doing a little dance here to figure it out... + # pytest_runtest_setup runs *before* pytest fixtures with scope="class". + # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest + # for the whole class and has to run things that are across all current + # databases, so we run this outside of the pytest fixture system altogether + # and ensure asyncio greenlet if any engines are async + + global _current_class + if _current_class is None: - asyncio._maybe_async(class_setup, item.parent.parent) + asyncio._maybe_async_provisioning( + plugin_base.start_test_class_outside_fixtures, + item.parent.parent.cls, + ) _current_class = item.parent.parent - # this is needed for the class-level, to ensure that the - # teardown runs after the class is completed with its own - # class-level teardown... def finalize(): global _current_class - asyncio._maybe_async(class_teardown, item.parent.parent) _current_class = None + asyncio._maybe_async_provisioning( + plugin_base.stop_test_class_outside_fixtures, + item.parent.parent.cls, + ) + item.parent.parent.addfinalizer(finalize) - asyncio._maybe_async(test_setup, item) +def pytest_runtest_call(item): + # runs inside of pytest function fixture scope + # before test function runs -def pytest_runtest_teardown(item): from sqlalchemy.testing import asyncio - # ...but this works better as the hook here rather than - # using a finalizer, as the finalizer seems to get in the way - # of the test reporting failures correctly (you get a bunch of - # pytest assertion stuff instead) - asyncio._maybe_async(test_teardown, item) + asyncio._maybe_async( + plugin_base.before_test, + item, + item.parent.module.__name__, + item.parent.cls, + item.name, + ) -def test_setup(item): - plugin_base.before_test( - item, item.parent.module.__name__, item.parent.cls, item.name - ) +def pytest_runtest_teardown(item, nextitem): + # runs inside of pytest function fixture scope + # after test function runs + from sqlalchemy.testing import asyncio -def test_teardown(item): - plugin_base.after_test(item) + asyncio._maybe_async(plugin_base.after_test, item) -def class_setup(item): +@pytest.fixture(scope="class") +def setup_class_methods(request): from sqlalchemy.testing import asyncio - asyncio._maybe_async_provisioning(plugin_base.start_test_class, item.cls) + cls = request.cls + + if hasattr(cls, "setup_test_class"): + asyncio._maybe_async(cls.setup_test_class) + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_setup(request) + + yield + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_teardown(request) + if hasattr(cls, "teardown_test_class"): + asyncio._maybe_async(cls.teardown_test_class) -def class_teardown(item): - plugin_base.stop_test_class(item.cls) + asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): + from sqlalchemy.testing import asyncio + + # called for each test + + self = request.instance + + # 1. run outer xdist-style setup + if hasattr(self, "setup_test"): + asyncio._maybe_async(self.setup_test) + + # alembic test suite is using setUp and tearDown + # xdist methods; support these in the test suite + # for the near term + if hasattr(self, "setUp"): + asyncio._maybe_async(self.setUp) + + # 2. run homegrown function level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_setup(request) + + # inside the yield: + + # 3. function level "autouse" fixtures under py3k (examples: TablesTest + # define tables / data, MappedTest define tables / mappers / data) + + # 4. function level fixtures defined on test functions themselves, + # e.g. "connection", "metadata" run next + + # 5. pytest hook pytest_runtest_call then runs + + # 6. test itself runs + + yield + + # yield finishes: + + # 7. pytest hook pytest_runtest_teardown hook runs, this is associated + # with fixtures close all sessions, provisioning.stop_test_class(), + # engines.testing_reaper -> ensure all connection pool connections + # are returned, engines created by testing_engine that aren't the + # config engine are disposed + + # 8. function level fixtures defined on test functions + # themselves, e.g. "connection" rolls back the transaction, "metadata" + # emits drop all + + # 9. function level "autouse" fixtures under py3k (examples: TablesTest / + # MappedTest delete table data, possibly drop tables and clear mappers + # depending on the flags defined by the test class) + + # 10. run homegrown function-level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_teardown(request) + + asyncio._maybe_async(plugin_base.after_test_fixtures, self) + + # 11. run outer xdist-style teardown + if hasattr(self, "tearDown"): + asyncio._maybe_async(self.tearDown) + + if hasattr(self, "teardown_test"): + asyncio._maybe_async(self.teardown_test) def getargspec(fn): @@ -461,6 +558,8 @@ def %(name)s(%(args)s): # for the wrapped function decorated.__module__ = fn.__module__ decorated.__name__ = fn.__name__ + if hasattr(fn, "pytestmark"): + decorated.pytestmark = fn.pytestmark return decorated return decorate @@ -470,6 +569,11 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): def skip_test_exception(self, *arg, **kw): return pytest.skip.Exception(*arg, **kw) + def mark_base_test_class(self): + return pytest.mark.usefixtures( + "setup_class_methods", "setup_test_methods" + ) + _combination_id_fns = { "i": lambda obj: obj, "r": repr, @@ -647,8 +751,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): fn = asyncio._maybe_async_wrapper(fn) # other wrappers may be added here - # now apply FixtureFunctionMarker - fn = fixture(fn) + if py2k and "autouse" in kw: + # py2k workaround for too-slow collection of autouse fixtures + # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for + # rationale. + + # comment this condition out in order to disable the + # py2k workaround entirely. + reinvent_fixtures_py2k.add_fixture(fn, fixture) + else: + # now apply FixtureFunctionMarker + fn = fixture(fn) + return fn if fn: diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py new file mode 100644 index 0000000000..36b68417bc --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py @@ -0,0 +1,112 @@ +""" +invent a quick version of pytest autouse fixtures as pytest's unacceptably slow +collection/high memory use in pytest 4.6.11, which is the highest version that +works in py2k. + +by "too-slow" we mean the test suite can't even manage to be collected for a +single process in less than 70 seconds or so and memory use seems to be very +high as well. for two or four workers the job just times out after ten +minutes. + +so instead we have invented a very limited form of these fixtures, as our +current use of "autouse" fixtures are limited to those in fixtures.py. + +assumptions for these fixtures: + +1. we are only using "function" or "class" scope + +2. the functions must be associated with a test class + +3. the fixture functions cannot themselves use pytest fixtures + +4. the fixture functions must use yield, not return + +When py2k support is removed and we can stay on a modern pytest version, this +can all be removed. + + +""" +import collections + + +_py2k_fixture_fn_names = collections.defaultdict(set) +_py2k_class_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) +_py2k_function_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) + +_py2k_cls_fixture_stack = [] +_py2k_fn_fixture_stack = [] + + +def add_fixture(fn, fixture): + assert fixture.scope in ("class", "function") + _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope)) + + +def scan_for_fixtures_to_use_for_class(item): + test_class = item.parent.parent.obj + + for name in _py2k_fixture_fn_names: + for fixture_fn, scope in _py2k_fixture_fn_names[name]: + meth = getattr(test_class, name, None) + if meth and meth.im_func is fixture_fn: + for sup in test_class.__mro__: + if name in sup.__dict__: + if scope == "class": + _py2k_class_fixtures[test_class][sup].add(meth) + elif scope == "function": + _py2k_function_fixtures[test_class][sup].add(meth) + break + break + + +def run_class_fixture_setup(request): + + cls = request.cls + self = cls.__new__(cls) + + fixtures_for_this_class = _py2k_class_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in cls.__mro__: + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_cls_fixture_stack.append(iter_) + + +def run_class_fixture_teardown(request): + while _py2k_cls_fixture_stack: + iter_ = _py2k_cls_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass + + +def run_fn_fixture_setup(request): + cls = request.cls + self = request.instance + + fixtures_for_this_class = _py2k_function_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in reversed(cls.__mro__): + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_fn_fixture_stack.append(iter_) + + +def run_fn_fixture_teardown(request): + while _py2k_fn_fixture_stack: + iter_ = _py2k_fn_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index 4ee0567f22..2fade1c32d 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -67,6 +67,7 @@ def setup_config(db_url, options, file_config, follower_ident): db_url = follower_url_from_main(db_url, follower_ident) db_opts = {} update_db_opts(db_url, db_opts) + db_opts["scope"] = "global" eng = engines.testing_engine(db_url, db_opts) post_configure_engine(db_url, eng, follower_ident) eng.connect().close() @@ -264,6 +265,7 @@ def drop_all_schema_objects(cfg, eng): if config.requirements.schemas.enabled_for_config(cfg): util.drop_all_tables(eng, inspector, schema=cfg.test_schema) + util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) drop_all_schema_objects_post_tables(cfg, eng) @@ -299,7 +301,7 @@ def update_db_opts(db_url, db_opts): def post_configure_engine(url, engine, follower_ident): """Perform extra steps after configuring an engine for testing. - (For the internal dialects, currently only used by sqlite.) + (For the internal dialects, currently only used by sqlite, oracle) """ pass @@ -375,7 +377,12 @@ def temp_table_keyword_args(cfg, eng): @register.init -def stop_test_class(config, db, testcls): +def prepare_for_drop_tables(config, connection): + pass + + +@register.init +def stop_test_class_outside_fixtures(config, db, testcls): pass diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index 6c3c1005ab..de157d028d 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -293,7 +293,7 @@ class ComponentReflectionTest(fixtures.TablesTest): from sqlalchemy import pool return engines.testing_engine( - options=dict(poolclass=pool.StaticPool) + options=dict(poolclass=pool.StaticPool, scope="class"), ) else: return config.db diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index e0fdbe47a7..e8dd6cf2c9 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -261,10 +261,6 @@ class ServerSideCursorsTest( ) return self.engine - def tearDown(self): - engines.testing_reaper.close_all() - self.engine.dispose() - @testing.combinations( ("global_string", True, "select 1", True), ("global_text", True, text("select 1"), True), @@ -309,24 +305,22 @@ class ServerSideCursorsTest( def test_conn_option(self): engine = self._fixture(False) - # should be enabled for this one - result = ( - engine.connect() - .execution_options(stream_results=True) - .exec_driver_sql("select 1") - ) - assert self._is_server_side(result.cursor) + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) def test_stmt_enabled_conn_option_disabled(self): engine = self._fixture(False) s = select(1).execution_options(stream_results=True) - # not this one - result = ( - engine.connect().execution_options(stream_results=False).execute(s) - ) - assert not self._is_server_side(result.cursor) + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) def test_aliases_and_ss(self): engine = self._fixture(False) @@ -344,8 +338,7 @@ class ServerSideCursorsTest( assert not self._is_server_side(result.cursor) result.close() - @testing.provide_metadata - def test_roundtrip_fetchall(self): + def test_roundtrip_fetchall(self, metadata): md = self.metadata engine = self._fixture(True) @@ -385,8 +378,7 @@ class ServerSideCursorsTest( 0, ) - @testing.provide_metadata - def test_roundtrip_fetchmany(self): + def test_roundtrip_fetchmany(self, metadata): md = self.metadata engine = self._fixture(True) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 3a5e02c32b..ebcceaae7c 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -511,24 +511,23 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @testing.fixture - def do_numeric_test(self, metadata): + def do_numeric_test(self, metadata, connection): @testing.emits_warning( r".*does \*not\* support Decimal objects natively" ) def run(type_, input_, output, filter_=None, check_scale=False): t = Table("t", metadata, Column("x", type_)) - t.create(testing.db) - with config.db.begin() as conn: - conn.execute(t.insert(), [{"x": x} for x in input_]) - - result = {row[0] for row in conn.execute(t.select())} - output = set(output) - if filter_: - result = set(filter_(x) for x in result) - output = set(filter_(x) for x in output) - eq_(result, output) - if check_scale: - eq_([str(x) for x in result], [str(x) for x in output]) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) return run @@ -1165,40 +1164,39 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): }, ) - def test_eval_none_flag_orm(self): + def test_eval_none_flag_orm(self, connection): Base = declarative_base() class Data(Base): __table__ = self.tables.data_table - s = Session(testing.db) + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() - d1 = Data(name="d1", data=None, nulldata=None) - s.add(d1) - s.commit() - - s.bulk_insert_mappings( - Data, [{"name": "d2", "data": None, "nulldata": None}] - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String()), - cast(self.tables.data_table.c.nulldata, String), + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] ) - .filter(self.tables.data_table.c.name == "d1") - .first(), - ("null", None), - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String()), - cast(self.tables.data_table.c.nulldata, String), + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), ) - .filter(self.tables.data_table.c.name == "d2") - .first(), - ("null", None), - ) class JSONLegacyStringCastIndexTest( diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index eb9fcd1cd1..01185c2841 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -14,6 +14,7 @@ import types from . import config from . import mock from .. import inspect +from ..engine import Connection from ..schema import Column from ..schema import DropConstraint from ..schema import DropTable @@ -207,11 +208,13 @@ def fail(msg): @decorator def provide_metadata(fn, *args, **kw): - """Provide bound MetaData for a single test, dropping afterwards.""" + """Provide bound MetaData for a single test, dropping afterwards. - # import cycle that only occurs with py2k's import resolver - # in py3k this can be moved top level. - from . import engines + Legacy; use the "metadata" pytest fixture. + + """ + + from . import fixtures metadata = schema.MetaData() self = args[0] @@ -220,7 +223,31 @@ def provide_metadata(fn, *args, **kw): try: return fn(*args, **kw) finally: - engines.drop_all_tables(metadata, config.db) + # close out some things that get in the way of dropping tables. + # when using the "metadata" fixture, there is a set ordering + # of things that makes sure things are cleaned up in order, however + # the simple "decorator" nature of this legacy function means + # we have to hardcode some of that cleanup ahead of time. + + # close ORM sessions + fixtures._close_all_sessions() + + # integrate with the "connection" fixture as there are many + # tests where it is used along with provide_metadata + if fixtures._connection_fixture_connection: + # TODO: this warning can be used to find all the places + # this is used with connection fixture + # warn("mixing legacy provide metadata with connection fixture") + drop_all_tables_from_metadata( + metadata, fixtures._connection_fixture_connection + ) + # as the provide_metadata fixture is often used with "testing.db", + # when we do the drop we have to commit the transaction so that + # the DB is actually updated as the CREATE would have been + # committed + fixtures._connection_fixture_connection.get_transaction().commit() + else: + drop_all_tables_from_metadata(metadata, config.db) self.metadata = prev_meta @@ -359,6 +386,29 @@ class adict(dict): get_all = __call__ +def drop_all_tables_from_metadata(metadata, engine_or_connection): + from . import engines + + def go(connection): + engines.testing_reaper.prepare_for_drop_tables(connection) + + if not connection.dialect.supports_alter: + from . import assertions + + with assertions.expect_warnings( + "Can't sort tables", assert_=False + ): + metadata.drop_all(connection) + else: + metadata.drop_all(connection) + + if not isinstance(engine_or_connection, Connection): + with engine_or_connection.begin() as connection: + go(connection) + else: + go(engine_or_connection) + + def drop_all_tables(engine, inspector, schema=None, include_names=None): if include_names is not None: diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 99ecb4fb34..ca5a3abded 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -230,13 +230,16 @@ class AsyncAdaptedQueue: return self.put_nowait(item) try: - if timeout: + if timeout is not None: return self.await_( asyncio.wait_for(self._queue.put(item), timeout) ) else: return self.await_(self._queue.put(item)) - except asyncio.queues.QueueFull as err: + except ( + asyncio.queues.QueueFull, + asyncio.exceptions.TimeoutError, + ) as err: compat.raise_( Full(), replace_context=err, @@ -254,14 +257,18 @@ class AsyncAdaptedQueue: def get(self, block=True, timeout=None): if not block: return self.get_nowait() + try: - if timeout: + if timeout is not None: return self.await_( asyncio.wait_for(self._queue.get(), timeout) ) else: return self.await_(self._queue.get()) - except asyncio.queues.QueueEmpty as err: + except ( + asyncio.queues.QueueEmpty, + asyncio.exceptions.TimeoutError, + ) as err: compat.raise_( Empty(), replace_context=err, diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index 0202768ae4..968a747008 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -18,7 +18,7 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2, metadata metadata = MetaData() diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 75a4f51cf8..a41a8b9f11 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -241,7 +241,7 @@ def assert_no_mappers(): class EnsureZeroed(fixtures.ORMTest): - def setup(self): + def setup_test(self): _sessions.clear() _mapper_registry.clear() @@ -1032,7 +1032,7 @@ class MemUsageWBackendTest(EnsureZeroed): t2_mapper = mapper(T2, t2) t1_mapper.add_property("bar", relationship(t2_mapper)) - s1 = fixture_session() + s1 = Session(testing.db) # this causes the path_registry to be invoked s1.query(t1_mapper)._compile_context() diff --git a/test/aaa_profiling/test_misc.py b/test/aaa_profiling/test_misc.py index db6fd4b718..5b30a3968b 100644 --- a/test/aaa_profiling/test_misc.py +++ b/test/aaa_profiling/test_misc.py @@ -19,7 +19,7 @@ from sqlalchemy.util import classproperty class EnumTest(fixtures.TestBase): __requires__ = ("cpython", "python_profiling_backend") - def setup(self): + def setup_test(self): class SomeEnum(object): # Implements PEP 435 in the minimal fashion needed by SQLAlchemy diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index f163078d80..8116e5f215 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -29,15 +29,13 @@ class NoCache(object): run_setup_bind = "each" @classmethod - def setup_class(cls): - super(NoCache, cls).setup_class() + def setup_test_class(cls): cls._cache = config.db._compiled_cache config.db._compiled_cache = None @classmethod - def teardown_class(cls): + def teardown_test_class(cls): config.db._compiled_cache = cls._cache - super(NoCache, cls).teardown_class() class MergeTest(NoCache, fixtures.MappedTest): diff --git a/test/aaa_profiling/test_pool.py b/test/aaa_profiling/test_pool.py index fd02f91395..da3c1c5256 100644 --- a/test/aaa_profiling/test_pool.py +++ b/test/aaa_profiling/test_pool.py @@ -17,7 +17,7 @@ class QueuePoolTest(fixtures.TestBase, AssertsExecutionResults): def close(self): pass - def setup(self): + def setup_test(self): # create a throwaway pool which # has the effect of initializing # class-level event listeners on Pool, diff --git a/test/base/test_events.py b/test/base/test_events.py index 19f68e9a35..68db5207ca 100644 --- a/test/base/test_events.py +++ b/test/base/test_events.py @@ -16,7 +16,7 @@ from sqlalchemy.testing.util import gc_collect class TearDownLocalEventsFixture(object): - def tearDown(self): + def teardown_test(self): classes = set() for entry in event.base._registrars.values(): for evt_cls in entry: @@ -30,7 +30,7 @@ class TearDownLocalEventsFixture(object): class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test class- and instance-level event registration.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, x, y): pass @@ -438,7 +438,7 @@ class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase): class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase): """test adaption of legacy args""" - def setUp(self): + def setup_test(self): class TargetEventsOne(event.Events): @event._legacy_signature("0.9", ["x", "y"]) def event_three(self, x, y, z, q): @@ -608,7 +608,7 @@ class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase): class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEventsOne(event.Events): def event_one(self, x, y): pass @@ -677,7 +677,7 @@ class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase): class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test default target acceptance.""" - def setUp(self): + def setup_test(self): class TargetEventsOne(event.Events): def event_one(self, x, y): pass @@ -734,7 +734,7 @@ class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test custom target acceptance.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): @classmethod def _accept_with(cls, target): @@ -771,7 +771,7 @@ class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase): class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase): """test that ad-hoc subclasses are garbage collected.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def some_event(self, x, y): pass @@ -797,7 +797,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase): """Test custom listen functions which change the listener function signature.""" - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): @classmethod def _listen(cls, event_key, add=False): @@ -855,7 +855,7 @@ class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase): class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, arg): pass @@ -889,7 +889,7 @@ class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, target, arg): pass @@ -1109,7 +1109,7 @@ class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase): class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase): - def setUp(self): + def setup_test(self): class TargetEvents(event.Events): def event_one(self, target, arg): pass diff --git a/test/base/test_inspect.py b/test/base/test_inspect.py index 15b98c848b..252d0d9777 100644 --- a/test/base/test_inspect.py +++ b/test/base/test_inspect.py @@ -13,7 +13,7 @@ class TestFixture(object): class TestInspection(fixtures.TestBase): - def tearDown(self): + def teardown_test(self): for type_ in list(inspection._registrars): if issubclass(type_, TestFixture): del inspection._registrars[type_] diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py index 14e87ef690..6320ef0527 100644 --- a/test/base/test_tutorials.py +++ b/test/base/test_tutorials.py @@ -48,11 +48,11 @@ class DocTest(fixtures.TestBase): ddl.sort_tables_and_constraints = self.orig_sort - def setup(self): + def setup_test(self): self._setup_logger() self._setup_create_table_patcher() - def teardown(self): + def teardown_test(self): self._teardown_create_table_patcher() self._teardown_logger() diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 8119612e1a..f0bb66aa9f 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -1814,7 +1814,7 @@ class CompileIdentityTest(fixtures.TestBase, AssertsCompiledSQL): class SchemaTest(fixtures.TestBase): - def setup(self): + def setup_test(self): t = Table( "sometable", MetaData(), diff --git a/test/dialect/mssql/test_deprecations.py b/test/dialect/mssql/test_deprecations.py index c869182c5a..27709beb05 100644 --- a/test/dialect/mssql/test_deprecations.py +++ b/test/dialect/mssql/test_deprecations.py @@ -31,7 +31,7 @@ class LegacySchemaAliasingTest(fixtures.TestBase, AssertsCompiledSQL): """ - def setup(self): + def setup_test(self): metadata = MetaData() self.t1 = table( "t1", diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index cdb37cc615..b806b9247f 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -455,7 +455,7 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): return testing.db.execution_options(isolation_level="AUTOCOMMIT") @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.connect().execution_options( isolation_level="AUTOCOMMIT" ) as conn: @@ -463,7 +463,6 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog") except: pass - super(MatchTest, cls).setup_class() @classmethod def insert_data(cls, connection): diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index 62292b9daa..7fd24e8b51 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -991,7 +991,7 @@ class SQLTest(fixtures.TestBase, AssertsCompiledSQL): class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = mysql.dialect() - def setup(self): + def setup_test(self): self.table = Table( "foos", MetaData(), @@ -1062,7 +1062,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpCommon(testing.AssertsCompiledSQL): - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index 40617e59ce..795b2cbd32 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -1115,7 +1115,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): class RawReflectionTest(fixtures.TestBase): __backend__ = True - def setup(self): + def setup_test(self): dialect = mysql.dialect() self.parser = _reflection.MySQLTableDefinitionParser( dialect, dialect.identifier_preparer diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 1b8b3fb89b..f09346eb32 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1355,7 +1355,7 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "oracle" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index df87fe89fc..32234bf653 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -439,7 +439,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.begin() as c: c.exec_driver_sql( """ @@ -471,7 +471,7 @@ end; assert isinstance(result.out_parameters["x_out"], int) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: conn.execute(text("DROP PROCEDURE foo")) diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index 81e4e4ab5a..0df4236e25 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -39,7 +39,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): # currently assuming full DBA privs for the user. # don't really know how else to go here unless # we connect as the other user. @@ -85,7 +85,7 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): conn.exec_driver_sql(stmt) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: for stmt in ( """ @@ -379,7 +379,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - def setup(self): + def setup_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("create table my_table (id integer)") conn.exec_driver_sql( @@ -389,7 +389,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): "create table foo_table (id integer) tablespace SYSTEM" ) - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("drop table my_temp_table") conn.exec_driver_sql("drop table my_table") @@ -421,7 +421,7 @@ class DontReflectIOTTest(fixtures.TestBase): __only_on__ = "oracle" __backend__ = True - def setup(self): + def setup_test(self): with testing.db.begin() as conn: conn.exec_driver_sql( """ @@ -438,7 +438,7 @@ class DontReflectIOTTest(fixtures.TestBase): """, ) - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("drop table admin_docindex") @@ -715,7 +715,7 @@ class DBLinkReflectionTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy.testing import config cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link") @@ -734,7 +734,7 @@ class DBLinkReflectionTest(fixtures.TestBase): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: conn.exec_driver_sql("drop synonym test_table_syn") conn.exec_driver_sql("drop table test_table") diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index f008ea0192..8ea7c0e044 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -1011,7 +1011,7 @@ class EuroNumericTest(fixtures.TestBase): __only_on__ = "oracle+cx_oracle" __backend__ = True - def setup(self): + def setup_test(self): connect = testing.db.pool._creator def _creator(): @@ -1023,7 +1023,7 @@ class EuroNumericTest(fixtures.TestBase): self.engine = testing_engine(options={"creator": _creator}) - def teardown(self): + def teardown_test(self): self.engine.dispose() def test_were_getting_a_comma(self): diff --git a/test/dialect/postgresql/test_async_pg_py3k.py b/test/dialect/postgresql/test_async_pg_py3k.py index fadf939b86..f6d48f3c65 100644 --- a/test/dialect/postgresql/test_async_pg_py3k.py +++ b/test/dialect/postgresql/test_async_pg_py3k.py @@ -27,7 +27,7 @@ class AsyncPgTest(fixtures.TestBase): # TODO: remove when Iae6ab95938a7e92b6d42086aec534af27b5577d3 # merges - from sqlalchemy.testing import engines + from sqlalchemy.testing import util as testing_util from sqlalchemy.sql import schema metadata = schema.MetaData() @@ -35,7 +35,7 @@ class AsyncPgTest(fixtures.TestBase): try: yield metadata finally: - engines.drop_all_tables(metadata, testing.db) + testing_util.drop_all_tables_from_metadata(metadata, testing.db) @async_test async def test_detect_stale_ddl_cache_raise_recover( diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 1763b210b2..b3a0b9bbde 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1810,7 +1810,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): class InsertOnConflictTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - def setup(self): + def setup_test(self): self.table1 = table1 = table( "mytable", column("myid", Integer), @@ -2222,7 +2222,7 @@ class DistinctOnTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - def setup(self): + def setup_test(self): self.table = Table( "t", MetaData(), @@ -2373,7 +2373,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = postgresql.dialect() - def setup(self): + def setup_test(self): self.table = Table( "t", MetaData(), @@ -2464,7 +2464,7 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "postgresql" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index f760a309b4..9c9d817bba 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -198,17 +198,17 @@ class ExecuteManyMode(object): @config.fixture() def connection(self): - eng = engines.testing_engine(options=self.options) + opts = dict(self.options) + opts["use_reaper"] = False + eng = engines.testing_engine(options=opts) conn = eng.connect() trans = conn.begin() - try: - yield conn - finally: - if trans.is_active: - trans.rollback() - conn.close() - eng.dispose() + yield conn + if trans.is_active: + trans.rollback() + conn.close() + eng.dispose() @classmethod def define_tables(cls, metadata): @@ -510,8 +510,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): # assert result.closed assert result.cursor is None - @testing.provide_metadata - def test_insert_returning_preexecute_pk(self, connection): + def test_insert_returning_preexecute_pk(self, metadata, connection): counter = itertools.count(1) t = Table( @@ -525,7 +524,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): ), Column("data", Integer), ) - self.metadata.create_all(connection) + metadata.create_all(connection) result = connection.execute( t.insert().return_defaults(), diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 94af168eee..c51fd19432 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -40,10 +40,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" __backend__ = True - def setup(self): + def setup_test(self): self.metadata = MetaData() - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: self.metadata.drop_all(conn) @@ -890,7 +890,7 @@ class ExtractTest(fixtures.TablesTest): def setup_bind(cls): from sqlalchemy import event - eng = engines.testing_engine() + eng = engines.testing_engine(options={"scope": "class"}) @event.listens_for(eng, "connect") def connect(dbapi_conn, rec): diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 754eff25a0..6586a8308d 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -80,26 +80,24 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): ]: sa.event.listen(metadata, "before_drop", sa.DDL(ddl)) - def test_foreign_table_is_reflected(self): + def test_foreign_table_is_reflected(self, connection): metadata = MetaData() - table = Table("test_foreigntable", metadata, autoload_with=testing.db) + table = Table("test_foreigntable", metadata, autoload_with=connection) eq_( set(table.columns.keys()), set(["id", "data"]), "Columns of reflected foreign table didn't equal expected columns", ) - def test_get_foreign_table_names(self): - inspector = inspect(testing.db) - with testing.db.connect(): - ft_names = inspector.get_foreign_table_names() - eq_(ft_names, ["test_foreigntable"]) + def test_get_foreign_table_names(self, connection): + inspector = inspect(connection) + ft_names = inspector.get_foreign_table_names() + eq_(ft_names, ["test_foreigntable"]) - def test_get_table_names_no_foreign(self): - inspector = inspect(testing.db) - with testing.db.connect(): - names = inspector.get_table_names() - eq_(names, ["testtable"]) + def test_get_table_names_no_foreign(self, connection): + inspector = inspect(connection) + names = inspector.get_table_names() + eq_(names, ["testtable"]) class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): @@ -133,22 +131,22 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): if testing.against("postgresql >= 11"): Index("my_index", dv.c.q) - def test_get_tablenames(self): + def test_get_tablenames(self, connection): assert {"data_values", "data_values_4_10"}.issubset( - inspect(testing.db).get_table_names() + inspect(connection).get_table_names() ) - def test_reflect_cols(self): - cols = inspect(testing.db).get_columns("data_values") + def test_reflect_cols(self, connection): + cols = inspect(connection).get_columns("data_values") eq_([c["name"] for c in cols], ["modulus", "data", "q"]) - def test_reflect_cols_from_partition(self): - cols = inspect(testing.db).get_columns("data_values_4_10") + def test_reflect_cols_from_partition(self, connection): + cols = inspect(connection).get_columns("data_values_4_10") eq_([c["name"] for c in cols], ["modulus", "data", "q"]) @testing.only_on("postgresql >= 11") - def test_reflect_index(self): - idx = inspect(testing.db).get_indexes("data_values") + def test_reflect_index(self, connection): + idx = inspect(connection).get_indexes("data_values") eq_( idx, [ @@ -162,8 +160,8 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): ) @testing.only_on("postgresql >= 11") - def test_reflect_index_from_partition(self): - idx = inspect(testing.db).get_indexes("data_values_4_10") + def test_reflect_index_from_partition(self, connection): + idx = inspect(connection).get_indexes("data_values_4_10") # note the name appears to be generated by PG, currently # 'data_values_4_10_q_idx' eq_( @@ -220,44 +218,43 @@ class MaterializedViewReflectionTest( testtable, "before_drop", sa.DDL("DROP VIEW test_regview") ) - def test_mview_is_reflected(self): + def test_mview_is_reflected(self, connection): metadata = MetaData() - table = Table("test_mview", metadata, autoload_with=testing.db) + table = Table("test_mview", metadata, autoload_with=connection) eq_( set(table.columns.keys()), set(["id", "data"]), "Columns of reflected mview didn't equal expected columns", ) - def test_mview_select(self): + def test_mview_select(self, connection): metadata = MetaData() - table = Table("test_mview", metadata, autoload_with=testing.db) - with testing.db.connect() as conn: - eq_(conn.execute(table.select()).fetchall(), [(89, "d1")]) + table = Table("test_mview", metadata, autoload_with=connection) + eq_(connection.execute(table.select()).fetchall(), [(89, "d1")]) - def test_get_view_names(self): - insp = inspect(testing.db) + def test_get_view_names(self, connection): + insp = inspect(connection) eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_plain(self): - insp = inspect(testing.db) + def test_get_view_names_plain(self, connection): + insp = inspect(connection) eq_( set(insp.get_view_names(include=("plain",))), set(["test_regview"]) ) - def test_get_view_names_plain_string(self): - insp = inspect(testing.db) + def test_get_view_names_plain_string(self, connection): + insp = inspect(connection) eq_(set(insp.get_view_names(include="plain")), set(["test_regview"])) - def test_get_view_names_materialized(self): - insp = inspect(testing.db) + def test_get_view_names_materialized(self, connection): + insp = inspect(connection) eq_( set(insp.get_view_names(include=("materialized",))), set(["test_mview"]), ) - def test_get_view_names_reflection_cache_ok(self): - insp = inspect(testing.db) + def test_get_view_names_reflection_cache_ok(self, connection): + insp = inspect(connection) eq_( set(insp.get_view_names(include=("plain",))), set(["test_regview"]) ) @@ -267,12 +264,12 @@ class MaterializedViewReflectionTest( ) eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_empty(self): - insp = inspect(testing.db) + def test_get_view_names_empty(self, connection): + insp = inspect(connection) assert_raises(ValueError, insp.get_view_names, include=()) - def test_get_view_definition(self): - insp = inspect(testing.db) + def test_get_view_definition(self, connection): + insp = inspect(connection) eq_( re.sub( r"[\n\t ]+", @@ -290,7 +287,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.begin() as con: for ddl in [ 'CREATE SCHEMA "SomeSchema"', @@ -334,7 +331,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as con: con.exec_driver_sql("DROP TABLE testtable") con.exec_driver_sql("DROP TABLE test_schema.testtable") @@ -350,9 +347,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') con.exec_driver_sql('DROP SCHEMA "SomeSchema"') - def test_table_is_reflected(self): + def test_table_is_reflected(self, connection): metadata = MetaData() - table = Table("testtable", metadata, autoload_with=testing.db) + table = Table("testtable", metadata, autoload_with=connection) eq_( set(table.columns.keys()), set(["question", "answer"]), @@ -360,9 +357,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) assert isinstance(table.c.answer.type, Integer) - def test_domain_is_reflected(self): + def test_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("testtable", metadata, autoload_with=testing.db) + table = Table("testtable", metadata, autoload_with=connection) eq_( str(table.columns.answer.server_default.arg), "42", @@ -372,28 +369,28 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): not table.columns.answer.nullable ), "Expected reflected column to not be nullable." - def test_enum_domain_is_reflected(self): + def test_enum_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("enum_test", metadata, autoload_with=testing.db) + table = Table("enum_test", metadata, autoload_with=connection) eq_(table.c.data.type.enums, ["test"]) - def test_array_domain_is_reflected(self): + def test_array_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("array_test", metadata, autoload_with=testing.db) + table = Table("array_test", metadata, autoload_with=connection) eq_(table.c.data.type.__class__, ARRAY) eq_(table.c.data.type.item_type.__class__, INTEGER) - def test_quoted_remote_schema_domain_is_reflected(self): + def test_quoted_remote_schema_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("quote_test", metadata, autoload_with=testing.db) + table = Table("quote_test", metadata, autoload_with=connection) eq_(table.c.data.type.__class__, INTEGER) - def test_table_is_reflected_test_schema(self): + def test_table_is_reflected_test_schema(self, connection): metadata = MetaData() table = Table( "testtable", metadata, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema", ) eq_( @@ -403,12 +400,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) assert isinstance(table.c.anything.type, Integer) - def test_schema_domain_is_reflected(self): + def test_schema_domain_is_reflected(self, connection): metadata = MetaData() table = Table( "testtable", metadata, - autoload_with=testing.db, + autoload_with=connection, schema="test_schema", ) eq_( @@ -420,9 +417,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): table.columns.answer.nullable ), "Expected reflected column to be nullable." - def test_crosschema_domain_is_reflected(self): + def test_crosschema_domain_is_reflected(self, connection): metadata = MetaData() - table = Table("crosschema", metadata, autoload_with=testing.db) + table = Table("crosschema", metadata, autoload_with=connection) eq_( str(table.columns.answer.server_default.arg), "0", @@ -432,7 +429,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): table.columns.answer.nullable ), "Expected reflected column to be nullable." - def test_unknown_types(self): + def test_unknown_types(self, connection): from sqlalchemy.dialects.postgresql import base ischema_names = base.PGDialect.ischema_names @@ -440,13 +437,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): try: m2 = MetaData() assert_raises( - exc.SAWarning, Table, "testtable", m2, autoload_with=testing.db + exc.SAWarning, Table, "testtable", m2, autoload_with=connection ) @testing.emits_warning("Did not recognize type") def warns(): m3 = MetaData() - t3 = Table("testtable", m3, autoload_with=testing.db) + t3 = Table("testtable", m3, autoload_with=connection) assert t3.c.answer.type.__class__ == sa.types.NullType finally: @@ -471,9 +468,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): subject = Table("subject", meta2, autoload_with=connection) eq_(subject.primary_key.columns.keys(), ["p2", "p1"]) - @testing.provide_metadata - def test_pg_weirdchar_reflection(self): - meta1 = self.metadata + def test_pg_weirdchar_reflection(self, metadata, connection): + meta1 = metadata subject = Table( "subject", meta1, Column("id$", Integer, primary_key=True) ) @@ -483,101 +479,91 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey("subject.id$")), ) - meta1.create_all(testing.db) + meta1.create_all(connection) meta2 = MetaData() - subject = Table("subject", meta2, autoload_with=testing.db) - referer = Table("referer", meta2, autoload_with=testing.db) + subject = Table("subject", meta2, autoload_with=connection) + referer = Table("referer", meta2, autoload_with=connection) self.assert_( (subject.c["id$"] == referer.c.ref).compare( subject.join(referer).onclause ) ) - @testing.provide_metadata - def test_reflect_default_over_128_chars(self): + def test_reflect_default_over_128_chars(self, metadata, connection): Table( "t", - self.metadata, + metadata, Column("x", String(200), server_default="abcd" * 40), - ).create(testing.db) + ).create(connection) m = MetaData() - t = Table("t", m, autoload_with=testing.db) + t = Table("t", m, autoload_with=connection) eq_( t.c.x.server_default.arg.text, "'%s'::character varying" % ("abcd" * 40), ) - @testing.fails_if("postgresql < 8.1", "schema name leaks in, not sure") - @testing.provide_metadata - def test_renamed_sequence_reflection(self): - metadata = self.metadata + def test_renamed_sequence_reflection(self, metadata, connection): Table("t", metadata, Column("id", Integer, primary_key=True)) - metadata.create_all(testing.db) + metadata.create_all(connection) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db, implicit_returning=False) + t2 = Table("t", m2, autoload_with=connection, implicit_returning=False) eq_(t2.c.id.server_default.arg.text, "nextval('t_id_seq'::regclass)") - with testing.db.begin() as conn: - r = conn.execute(t2.insert()) - eq_(r.inserted_primary_key, (1,)) + r = connection.execute(t2.insert()) + eq_(r.inserted_primary_key, (1,)) - with testing.db.begin() as conn: - conn.exec_driver_sql( - "alter table t_id_seq rename to foobar_id_seq" - ) + connection.exec_driver_sql( + "alter table t_id_seq rename to foobar_id_seq" + ) m3 = MetaData() - t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False) + t3 = Table("t", m3, autoload_with=connection, implicit_returning=False) eq_( t3.c.id.server_default.arg.text, "nextval('foobar_id_seq'::regclass)", ) - with testing.db.begin() as conn: - r = conn.execute(t3.insert()) - eq_(r.inserted_primary_key, (2,)) + r = connection.execute(t3.insert()) + eq_(r.inserted_primary_key, (2,)) - @testing.provide_metadata - def test_altered_type_autoincrement_pk_reflection(self): - metadata = self.metadata + def test_altered_type_autoincrement_pk_reflection( + self, metadata, connection + ): + metadata = metadata Table( "t", metadata, Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all(testing.db) + metadata.create_all(connection) - with testing.db.begin() as conn: - conn.exec_driver_sql( - "alter table t alter column id type varchar(50)" - ) + connection.exec_driver_sql( + "alter table t alter column id type varchar(50)" + ) m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_(t2.c.id.autoincrement, False) eq_(t2.c.x.autoincrement, False) - @testing.provide_metadata - def test_renamed_pk_reflection(self): - metadata = self.metadata + def test_renamed_pk_reflection(self, metadata, connection): + metadata = metadata Table("t", metadata, Column("id", Integer, primary_key=True)) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("alter table t rename id to t_id") + metadata.create_all(connection) + connection.exec_driver_sql("alter table t rename id to t_id") m2 = MetaData() - t2 = Table("t", m2, autoload_with=testing.db) + t2 = Table("t", m2, autoload_with=connection) eq_([c.name for c in t2.primary_key], ["t_id"]) - @testing.provide_metadata - def test_has_temporary_table(self): - assert not inspect(testing.db).has_table("some_temp_table") + def test_has_temporary_table(self, metadata, connection): + assert not inspect(connection).has_table("some_temp_table") user_tmp = Table( "some_temp_table", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("name", String(50)), prefixes=["TEMPORARY"], ) - user_tmp.create(testing.db) - assert inspect(testing.db).has_table("some_temp_table") + user_tmp.create(connection) + assert inspect(connection).has_table("some_temp_table") def test_cross_schema_reflection_one(self, metadata, connection): @@ -898,19 +884,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): A_table.create(connection, checkfirst=True) assert inspect(connection).has_table("A") - def test_uppercase_lowercase_sequence(self): + def test_uppercase_lowercase_sequence(self, connection): a_seq = Sequence("a") A_seq = Sequence("A") - a_seq.create(testing.db) - assert testing.db.dialect.has_sequence(testing.db, "a") - assert not testing.db.dialect.has_sequence(testing.db, "A") - A_seq.create(testing.db, checkfirst=True) - assert testing.db.dialect.has_sequence(testing.db, "A") + a_seq.create(connection) + assert connection.dialect.has_sequence(connection, "a") + assert not connection.dialect.has_sequence(connection, "A") + A_seq.create(connection, checkfirst=True) + assert connection.dialect.has_sequence(connection, "A") - a_seq.drop(testing.db) - A_seq.drop(testing.db) + a_seq.drop(connection) + A_seq.drop(connection) def test_index_reflection(self, metadata, connection): """Reflecting expression-based indexes should warn""" @@ -960,11 +946,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_index_reflection_partial(self, connection): + def test_index_reflection_partial(self, metadata, connection): """Reflect the filter defintion on partial indexes""" - metadata = self.metadata + metadata = metadata t1 = Table( "table1", @@ -978,7 +963,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): metadata.create_all(connection) - ind = testing.db.dialect.get_indexes(connection, t1, None) + ind = connection.dialect.get_indexes(connection, t1, None) partial_definitions = [] for ix in ind: @@ -1073,15 +1058,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): compile_exprs(r3.expressions), ) - @testing.provide_metadata - def test_index_reflection_modified(self): + def test_index_reflection_modified(self, metadata, connection): """reflect indexes when a column name has changed - PG 9 does not update the name of the column in the index def. [ticket:2141] """ - metadata = self.metadata + metadata = metadata Table( "t", @@ -1089,26 +1073,21 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)") - conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") + metadata.create_all(connection) + connection.exec_driver_sql("CREATE INDEX idx1 ON t (x)") + connection.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") - ind = testing.db.dialect.get_indexes(conn, "t", None) - expected = [ - {"name": "idx1", "unique": False, "column_names": ["y"]} - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] + ind = connection.dialect.get_indexes(connection, "t", None) + expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] - eq_(ind, expected) + eq_(ind, expected) - @testing.fails_if("postgresql < 8.2", "reloptions not supported") - @testing.provide_metadata - def test_index_reflection_with_storage_options(self): + def test_index_reflection_with_storage_options(self, metadata, connection): """reflect indexes with storage options set""" - metadata = self.metadata + metadata = metadata Table( "t", @@ -1116,70 +1095,63 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all(testing.db) + metadata.create_all(connection) - with testing.db.begin() as conn: - conn.exec_driver_sql( - "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)" - ) + connection.exec_driver_sql( + "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)" + ) - ind = testing.db.dialect.get_indexes(conn, "t", None) + ind = testing.db.dialect.get_indexes(connection, "t", None) - expected = [ - { - "unique": False, - "column_names": ["x"], - "name": "idx1", - "dialect_options": { - "postgresql_with": {"fillfactor": "50"} - }, - } - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - eq_(ind, expected) + expected = [ + { + "unique": False, + "column_names": ["x"], + "name": "idx1", + "dialect_options": {"postgresql_with": {"fillfactor": "50"}}, + } + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_(ind, expected) - m = MetaData() - t1 = Table("t", m, autoload_with=conn) - eq_( - list(t1.indexes)[0].dialect_options["postgresql"]["with"], - {"fillfactor": "50"}, - ) + m = MetaData() + t1 = Table("t", m, autoload_with=connection) + eq_( + list(t1.indexes)[0].dialect_options["postgresql"]["with"], + {"fillfactor": "50"}, + ) - @testing.provide_metadata - def test_index_reflection_with_access_method(self): + def test_index_reflection_with_access_method(self, metadata, connection): """reflect indexes with storage options set""" - metadata = self.metadata - Table( "t", metadata, Column("id", Integer, primary_key=True), Column("x", ARRAY(Integer)), ) - metadata.create_all(testing.db) - with testing.db.begin() as conn: - conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)") + metadata.create_all(connection) + connection.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)") - ind = testing.db.dialect.get_indexes(conn, "t", None) - expected = [ - { - "unique": False, - "column_names": ["x"], - "name": "idx1", - "dialect_options": {"postgresql_using": "gin"}, - } - ] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] - eq_(ind, expected) - m = MetaData() - t1 = Table("t", m, autoload_with=conn) - eq_( - list(t1.indexes)[0].dialect_options["postgresql"]["using"], - "gin", - ) + ind = testing.db.dialect.get_indexes(connection, "t", None) + expected = [ + { + "unique": False, + "column_names": ["x"], + "name": "idx1", + "dialect_options": {"postgresql_using": "gin"}, + } + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_(ind, expected) + m = MetaData() + t1 = Table("t", m, autoload_with=connection) + eq_( + list(t1.indexes)[0].dialect_options["postgresql"]["using"], + "gin", + ) @testing.skip_if("postgresql < 11.0", "indnkeyatts not supported") def test_index_reflection_with_include(self, metadata, connection): @@ -1199,7 +1171,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): # [{'column_names': ['x', 'name'], # 'name': 'idx1', 'unique': False}] - ind = testing.db.dialect.get_indexes(connection, "t", None) + ind = connection.dialect.get_indexes(connection, "t", None) eq_( ind, [ @@ -1286,15 +1258,14 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): for fk in fks: eq_(fk, fk_ref[fk["name"]]) - @testing.provide_metadata - def test_inspect_enums_schema(self, connection): + def test_inspect_enums_schema(self, metadata, connection): enum_type = postgresql.ENUM( "sad", "ok", "happy", name="mood", schema="test_schema", - metadata=self.metadata, + metadata=metadata, ) enum_type.create(connection) inspector = inspect(connection) @@ -1310,13 +1281,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums(self): + def test_inspect_enums(self, metadata, connection): enum_type = postgresql.ENUM( - "cat", "dog", "rat", name="pet", metadata=self.metadata + "cat", "dog", "rat", name="pet", metadata=metadata ) - enum_type.create(testing.db) - inspector = inspect(testing.db) + enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums(), [ @@ -1329,17 +1299,16 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums_case_sensitive(self): + def test_inspect_enums_case_sensitive(self, metadata, connection): sa.event.listen( - self.metadata, + metadata, "before_create", sa.DDL('create schema "TestSchema"'), ) sa.event.listen( - self.metadata, + metadata, "after_drop", - sa.DDL('drop schema "TestSchema" cascade'), + sa.DDL('drop schema if exists "TestSchema" cascade'), ) for enum in "lower_case", "UpperCase", "Name.With.Dot": @@ -1350,11 +1319,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "CapsTwo", name=enum, schema=schema, - metadata=self.metadata, + metadata=metadata, ) - self.metadata.create_all(testing.db) - inspector = inspect(testing.db) + metadata.create_all(connection) + inspector = inspect(connection) for schema in None, "test_schema", "TestSchema": eq_( sorted( @@ -1382,17 +1351,18 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums_case_sensitive_from_table(self): + def test_inspect_enums_case_sensitive_from_table( + self, metadata, connection + ): sa.event.listen( - self.metadata, + metadata, "before_create", sa.DDL('create schema "TestSchema"'), ) sa.event.listen( - self.metadata, + metadata, "after_drop", - sa.DDL('drop schema "TestSchema" cascade'), + sa.DDL('drop schema if exists "TestSchema" cascade'), ) counter = itertools.count() @@ -1403,19 +1373,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "CapsOne", "CapsTwo", name=enum, - metadata=self.metadata, + metadata=metadata, schema=schema, ) Table( "t%d" % next(counter), - self.metadata, + metadata, Column("q", enum_type), ) - self.metadata.create_all(testing.db) + metadata.create_all(connection) - inspector = inspect(testing.db) + inspector = inspect(connection) counter = itertools.count() for enum in "lower_case", "UpperCase", "Name.With.Dot": for schema in None, "test_schema", "TestSchema": @@ -1439,10 +1409,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enums_star(self): + def test_inspect_enums_star(self, metadata, connection): enum_type = postgresql.ENUM( - "cat", "dog", "rat", name="pet", metadata=self.metadata + "cat", "dog", "rat", name="pet", metadata=metadata ) schema_enum_type = postgresql.ENUM( "sad", @@ -1450,11 +1419,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): "happy", name="mood", schema="test_schema", - metadata=self.metadata, + metadata=metadata, ) - enum_type.create(testing.db) - schema_enum_type.create(testing.db) - inspector = inspect(testing.db) + enum_type.create(connection) + schema_enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums(), @@ -1486,11 +1455,10 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enum_empty(self): - enum_type = postgresql.ENUM(name="empty", metadata=self.metadata) - enum_type.create(testing.db) - inspector = inspect(testing.db) + def test_inspect_enum_empty(self, metadata, connection): + enum_type = postgresql.ENUM(name="empty", metadata=metadata) + enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums(), @@ -1504,13 +1472,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - @testing.provide_metadata - def test_inspect_enum_empty_from_table(self): + def test_inspect_enum_empty_from_table(self, metadata, connection): Table( - "t", self.metadata, Column("x", postgresql.ENUM(name="empty")) - ).create(testing.db) + "t", metadata, Column("x", postgresql.ENUM(name="empty")) + ).create(connection) - t = Table("t", MetaData(), autoload_with=testing.db) + t = Table("t", MetaData(), autoload_with=connection) eq_(t.c.x.type.enums, []) def test_reflection_with_unique_constraint(self, metadata, connection): @@ -1749,12 +1716,12 @@ class CustomTypeReflectionTest(fixtures.TestBase): ischema_names = None - def setup(self): + def setup_test(self): ischema_names = postgresql.PGDialect.ischema_names postgresql.PGDialect.ischema_names = ischema_names.copy() self.ischema_names = ischema_names - def teardown(self): + def teardown_test(self): postgresql.PGDialect.ischema_names = self.ischema_names self.ischema_names = None @@ -1788,55 +1755,51 @@ class IntervalReflectionTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - def test_interval_types(self): - for sym in [ - "YEAR", - "MONTH", - "DAY", - "HOUR", - "MINUTE", - "SECOND", - "YEAR TO MONTH", - "DAY TO HOUR", - "DAY TO MINUTE", - "DAY TO SECOND", - "HOUR TO MINUTE", - "HOUR TO SECOND", - "MINUTE TO SECOND", - ]: - self._test_interval_symbol(sym) - - @testing.provide_metadata - def _test_interval_symbol(self, sym): + @testing.combinations( + ("YEAR",), + ("MONTH",), + ("DAY",), + ("HOUR",), + ("MINUTE",), + ("SECOND",), + ("YEAR TO MONTH",), + ("DAY TO HOUR",), + ("DAY TO MINUTE",), + ("DAY TO SECOND",), + ("HOUR TO MINUTE",), + ("HOUR TO SECOND",), + ("MINUTE TO SECOND",), + argnames="sym", + ) + def test_interval_types(self, sym, metadata, connection): t = Table( "i_test", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data1", INTERVAL(fields=sym)), ) - t.create(testing.db) + t.create(connection) columns = { rec["name"]: rec - for rec in inspect(testing.db).get_columns("i_test") + for rec in inspect(connection).get_columns("i_test") } assert isinstance(columns["data1"]["type"], INTERVAL) eq_(columns["data1"]["type"].fields, sym.lower()) eq_(columns["data1"]["type"].precision, None) - @testing.provide_metadata - def test_interval_precision(self): + def test_interval_precision(self, metadata, connection): t = Table( "i_test", - self.metadata, + metadata, Column("id", Integer, primary_key=True), Column("data1", INTERVAL(precision=6)), ) - t.create(testing.db) + t.create(connection) columns = { rec["name"]: rec - for rec in inspect(testing.db).get_columns("i_test") + for rec in inspect(connection).get_columns("i_test") } assert isinstance(columns["data1"]["type"], INTERVAL) eq_(columns["data1"]["type"].fields, None) @@ -1871,8 +1834,8 @@ class IdentityReflectionTest(fixtures.TablesTest): Column("id4", SmallInteger, Identity()), ) - def test_reflect_identity(self): - insp = inspect(testing.db) + def test_reflect_identity(self, connection): + insp = inspect(connection) default = dict( always=False, start=1, diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index e8a1876c7a..6202f8f868 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -49,7 +49,6 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes -from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -156,8 +155,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql > 8.3" - @testing.provide_metadata - def test_create_table(self, connection): + def test_create_table(self, metadata, connection): metadata = self.metadata t1 = Table( "table", @@ -177,8 +175,8 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): [(1, "two"), (2, "three"), (3, "three")], ) - @testing.combinations(None, "foo") - def test_create_table_schema_translate_map(self, symbol_name): + @testing.combinations(None, "foo", argnames="symbol_name") + def test_create_table_schema_translate_map(self, connection, symbol_name): # note we can't use the fixture here because it will not drop # from the correct schema metadata = MetaData() @@ -199,35 +197,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), schema=symbol_name, ) - with testing.db.begin() as conn: - conn = conn.execution_options( - schema_translate_map={symbol_name: testing.config.test_schema} - ) - t1.create(conn) - assert "schema_enum" in [ - e["name"] - for e in inspect(conn).get_enums( - schema=testing.config.test_schema - ) - ] - t1.create(conn, checkfirst=True) + conn = connection.execution_options( + schema_translate_map={symbol_name: testing.config.test_schema} + ) + t1.create(conn) + assert "schema_enum" in [ + e["name"] + for e in inspect(conn).get_enums(schema=testing.config.test_schema) + ] + t1.create(conn, checkfirst=True) - conn.execute(t1.insert(), value="two") - conn.execute(t1.insert(), value="three") - conn.execute(t1.insert(), value="three") - eq_( - conn.execute(t1.select().order_by(t1.c.id)).fetchall(), - [(1, "two"), (2, "three"), (3, "three")], - ) + conn.execute(t1.insert(), value="two") + conn.execute(t1.insert(), value="three") + conn.execute(t1.insert(), value="three") + eq_( + conn.execute(t1.select().order_by(t1.c.id)).fetchall(), + [(1, "two"), (2, "three"), (3, "three")], + ) - t1.drop(conn) - assert "schema_enum" not in [ - e["name"] - for e in inspect(conn).get_enums( - schema=testing.config.test_schema - ) - ] - t1.drop(conn, checkfirst=True) + t1.drop(conn) + assert "schema_enum" not in [ + e["name"] + for e in inspect(conn).get_enums(schema=testing.config.test_schema) + ] + t1.drop(conn, checkfirst=True) def test_name_required(self, metadata, connection): etype = Enum("four", "five", "six", metadata=metadata) @@ -270,8 +263,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): [util.u("réveillé"), util.u("drôle"), util.u("S’il")], ) - @testing.provide_metadata - def test_non_native_enum(self, connection): + def test_non_native_enum(self, metadata, connection): metadata = self.metadata t1 = Table( "foo", @@ -290,10 +282,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) def go(): - t1.create(testing.db) + t1.create(connection) self.assert_sql( - testing.db, + connection, go, [ ( @@ -307,8 +299,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection.execute(t1.insert(), {"bar": "two"}) eq_(connection.scalar(select(t1.c.bar)), "two") - @testing.provide_metadata - def test_non_native_enum_w_unicode(self, connection): + def test_non_native_enum_w_unicode(self, metadata, connection): metadata = self.metadata t1 = Table( "foo", @@ -326,10 +317,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) def go(): - t1.create(testing.db) + t1.create(connection) self.assert_sql( - testing.db, + connection, go, [ ( @@ -346,8 +337,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection.execute(t1.insert(), {"bar": util.u("Ü")}) eq_(connection.scalar(select(t1.c.bar)), util.u("Ü")) - @testing.provide_metadata - def test_disable_create(self): + def test_disable_create(self, metadata, connection): metadata = self.metadata e1 = postgresql.ENUM( @@ -357,13 +347,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): t1 = Table("e1", metadata, Column("c1", e1)) # table can be created separately # without conflict - e1.create(bind=testing.db) - t1.create(testing.db) - t1.drop(testing.db) - e1.drop(bind=testing.db) + e1.create(bind=connection) + t1.create(connection) + t1.drop(connection) + e1.drop(bind=connection) - @testing.provide_metadata - def test_dont_keep_checking(self, connection): + def test_dont_keep_checking(self, metadata, connection): metadata = self.metadata e1 = postgresql.ENUM("one", "two", "three", name="myenum") @@ -560,11 +549,10 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): e["name"] for e in inspect(connection).get_enums() ] - def test_non_native_dialect(self): - engine = engines.testing_engine() + def test_non_native_dialect(self, metadata, testing_engine): + engine = testing_engine() engine.connect() engine.dialect.supports_native_enum = False - metadata = MetaData() t1 = Table( "foo", metadata, @@ -583,21 +571,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): def go(): t1.create(engine) - try: - self.assert_sql( - engine, - go, - [ - ( - "CREATE TABLE foo (bar " - "VARCHAR(5), CONSTRAINT myenum CHECK " - "(bar IN ('one', 'two', 'three')))", - {}, - ) - ], - ) - finally: - metadata.drop_all(engine) + self.assert_sql( + engine, + go, + [ + ( + "CREATE TABLE foo (bar " + "VARCHAR(5), CONSTRAINT myenum CHECK " + "(bar IN ('one', 'two', 'three')))", + {}, + ) + ], + ) def test_standalone_enum(self, connection, metadata): etype = Enum( @@ -605,26 +590,26 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ) etype.create(connection) try: - assert testing.db.dialect.has_type(connection, "fourfivesixtype") + assert connection.dialect.has_type(connection, "fourfivesixtype") finally: etype.drop(connection) - assert not testing.db.dialect.has_type( + assert not connection.dialect.has_type( connection, "fourfivesixtype" ) metadata.create_all(connection) try: - assert testing.db.dialect.has_type(connection, "fourfivesixtype") + assert connection.dialect.has_type(connection, "fourfivesixtype") finally: metadata.drop_all(connection) - assert not testing.db.dialect.has_type( + assert not connection.dialect.has_type( connection, "fourfivesixtype" ) - def test_no_support(self): + def test_no_support(self, testing_engine): def server_version_info(self): return (8, 2) - e = engines.testing_engine() + e = testing_engine() dialect = e.dialect dialect._get_server_version_info = server_version_info @@ -692,8 +677,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): eq_(t2.c.value2.type.name, "fourfivesixtype") eq_(t2.c.value2.type.schema, "test_schema") - @testing.provide_metadata - def test_custom_subclass(self, connection): + def test_custom_subclass(self, metadata, connection): class MyEnum(TypeDecorator): impl = Enum("oneHI", "twoHI", "threeHI", name="myenum") @@ -708,13 +692,12 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): return value t1 = Table("table1", self.metadata, Column("data", MyEnum())) - self.metadata.create_all(testing.db) + self.metadata.create_all(connection) connection.execute(t1.insert(), {"data": "two"}) eq_(connection.scalar(select(t1.c.data)), "twoHITHERE") - @testing.provide_metadata - def test_generic_w_pg_variant(self, connection): + def test_generic_w_pg_variant(self, metadata, connection): some_table = Table( "some_table", self.metadata, @@ -752,8 +735,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): e["name"] for e in inspect(connection).get_enums() ] - @testing.provide_metadata - def test_generic_w_some_other_variant(self, connection): + def test_generic_w_some_other_variant(self, metadata, connection): some_table = Table( "some_table", self.metadata, @@ -809,26 +791,28 @@ class RegClassTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True - @staticmethod - def _scalar(expression): - with testing.db.connect() as conn: - return conn.scalar(select(expression)) + @testing.fixture() + def scalar(self, connection): + def go(expression): + return connection.scalar(select(expression)) - def test_cast_name(self): - eq_(self._scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class") + return go - def test_cast_path(self): + def test_cast_name(self, scalar): + eq_(scalar(cast("pg_class", postgresql.REGCLASS)), "pg_class") + + def test_cast_path(self, scalar): eq_( - self._scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)), + scalar(cast("pg_catalog.pg_class", postgresql.REGCLASS)), "pg_class", ) - def test_cast_oid(self): + def test_cast_oid(self, scalar): regclass = cast("pg_class", postgresql.REGCLASS) - oid = self._scalar(cast(regclass, postgresql.OID)) + oid = scalar(cast(regclass, postgresql.OID)) assert isinstance(oid, int) eq_( - self._scalar( + scalar( cast(type_coerce(oid, postgresql.OID), postgresql.REGCLASS) ), "pg_class", @@ -1339,13 +1323,12 @@ class ArrayRoundTripTest(object): Column("dimarr", ProcValue), ) - def _fixture_456(self, table): - with testing.db.begin() as conn: - conn.execute(table.insert(), intarr=[4, 5, 6]) + def _fixture_456(self, table, connection): + connection.execute(table.insert(), intarr=[4, 5, 6]) - def test_reflect_array_column(self): + def test_reflect_array_column(self, connection): metadata2 = MetaData() - tbl = Table("arrtable", metadata2, autoload_with=testing.db) + tbl = Table("arrtable", metadata2, autoload_with=connection) assert isinstance(tbl.c.intarr.type, self.ARRAY) assert isinstance(tbl.c.strarr.type, self.ARRAY) assert isinstance(tbl.c.intarr.type.item_type, Integer) @@ -1564,7 +1547,7 @@ class ArrayRoundTripTest(object): def test_array_getitem_single_exec(self, connection): arrtable = self.tables.arrtable - self._fixture_456(arrtable) + self._fixture_456(arrtable, connection) eq_(connection.scalar(select(arrtable.c.intarr[2])), 5) connection.execute(arrtable.update().values({arrtable.c.intarr[2]: 7})) eq_(connection.scalar(select(arrtable.c.intarr[2])), 7) @@ -1654,11 +1637,10 @@ class ArrayRoundTripTest(object): set([("1", "2", "3"), ("4", "5", "6"), (("4", "5"), ("6", "7"))]), ) - def test_array_plus_native_enum_create(self): - m = MetaData() + def test_array_plus_native_enum_create(self, metadata, connection): t = Table( "t", - m, + metadata, Column( "data_1", self.ARRAY(postgresql.ENUM("a", "b", "c", name="my_enum_1")), @@ -1669,13 +1651,13 @@ class ArrayRoundTripTest(object): ), ) - t.create(testing.db) + t.create(connection) eq_( - set(e["name"] for e in inspect(testing.db).get_enums()), + set(e["name"] for e in inspect(connection).get_enums()), set(["my_enum_1", "my_enum_2"]), ) - t.drop(testing.db) - eq_(inspect(testing.db).get_enums(), []) + t.drop(connection) + eq_(inspect(connection).get_enums(), []) class CoreArrayRoundTripTest( @@ -1690,33 +1672,35 @@ class PGArrayRoundTripTest( ): ARRAY = postgresql.ARRAY - @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) - def test_undim_array_contains_typed_exec(self, struct): + @testing.combinations( + (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct" + ) + def test_undim_array_contains_typed_exec(self, struct, connection): arrtable = self.tables.arrtable - self._fixture_456(arrtable) - with testing.db.begin() as conn: - eq_( - conn.scalar( - select(arrtable.c.intarr).where( - arrtable.c.intarr.contains(struct([4, 5])) - ) - ), - [4, 5, 6], - ) + self._fixture_456(arrtable, connection) + eq_( + connection.scalar( + select(arrtable.c.intarr).where( + arrtable.c.intarr.contains(struct([4, 5])) + ) + ), + [4, 5, 6], + ) - @testing.combinations((set,), (list,), (lambda elem: (x for x in elem),)) - def test_dim_array_contains_typed_exec(self, struct): + @testing.combinations( + (set,), (list,), (lambda elem: (x for x in elem),), argnames="struct" + ) + def test_dim_array_contains_typed_exec(self, struct, connection): dim_arrtable = self.tables.dim_arrtable - self._fixture_456(dim_arrtable) - with testing.db.begin() as conn: - eq_( - conn.scalar( - select(dim_arrtable.c.intarr).where( - dim_arrtable.c.intarr.contains(struct([4, 5])) - ) - ), - [4, 5, 6], - ) + self._fixture_456(dim_arrtable, connection) + eq_( + connection.scalar( + select(dim_arrtable.c.intarr).where( + dim_arrtable.c.intarr.contains(struct([4, 5])) + ) + ), + [4, 5, 6], + ) def test_array_contained_by_exec(self, connection): arrtable = self.tables.arrtable @@ -1730,7 +1714,7 @@ class PGArrayRoundTripTest( def test_undim_array_empty(self, connection): arrtable = self.tables.arrtable - self._fixture_456(arrtable) + self._fixture_456(arrtable, connection) eq_( connection.scalar( select(arrtable.c.intarr).where(arrtable.c.intarr.contains([])) @@ -1782,8 +1766,9 @@ class ArrayEnum(fixtures.TestBase): sqltypes.ARRAY, postgresql.ARRAY, argnames="array_cls" ) @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") - @testing.provide_metadata - def test_raises_non_native_enums(self, array_cls, enum_cls): + def test_raises_non_native_enums( + self, metadata, connection, array_cls, enum_cls + ): Table( "my_table", self.metadata, @@ -1808,7 +1793,7 @@ class ArrayEnum(fixtures.TestBase): "for ARRAY of non-native ENUM; please specify " "create_constraint=False on this Enum datatype.", self.metadata.create_all, - testing.db, + connection, ) @testing.combinations(sqltypes.Enum, postgresql.ENUM, argnames="enum_cls") @@ -1818,8 +1803,7 @@ class ArrayEnum(fixtures.TestBase): (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")), argnames="array_cls", ) - @testing.provide_metadata - def test_array_of_enums(self, array_cls, enum_cls, connection): + def test_array_of_enums(self, array_cls, enum_cls, metadata, connection): tbl = Table( "enum_table", self.metadata, @@ -1875,8 +1859,7 @@ class ArrayJSON(fixtures.TestBase): @testing.combinations( sqltypes.JSON, postgresql.JSON, postgresql.JSONB, argnames="json_cls" ) - @testing.provide_metadata - def test_array_of_json(self, array_cls, json_cls, connection): + def test_array_of_json(self, array_cls, json_cls, metadata, connection): tbl = Table( "json_table", self.metadata, @@ -1982,19 +1965,38 @@ class HashableFlagORMTest(fixtures.TestBase): }, ], ), + ( + "HSTORE", + postgresql.HSTORE(), + [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}], + testing.requires.hstore, + ), + ( + "JSONB", + postgresql.JSONB(), + [ + {"a": "1", "b": "2", "c": "3"}, + { + "d": "4", + "e": {"e1": "5", "e2": "6"}, + "f": {"f1": [9, 10, 11]}, + }, + ], + testing.requires.postgresql_jsonb, + ), + argnames="type_,data", id_="iaa", ) - @testing.provide_metadata - def test_hashable_flag(self, type_, data): - Base = declarative_base(metadata=self.metadata) + def test_hashable_flag(self, metadata, connection, type_, data): + Base = declarative_base(metadata=metadata) class A(Base): __tablename__ = "a1" id = Column(Integer, primary_key=True) data = Column(type_) - Base.metadata.create_all(testing.db) - s = Session(testing.db) + Base.metadata.create_all(connection) + s = Session(connection) s.add_all([A(data=elem) for elem in data]) s.commit() @@ -2006,27 +2008,6 @@ class HashableFlagORMTest(fixtures.TestBase): list(enumerate(data, 1)), ) - @testing.requires.hstore - def test_hstore(self): - self.test_hashable_flag( - postgresql.HSTORE(), - [{"a": "1", "b": "2", "c": "3"}, {"d": "4", "e": "5", "f": "6"}], - ) - - @testing.requires.postgresql_jsonb - def test_jsonb(self): - self.test_hashable_flag( - postgresql.JSONB(), - [ - {"a": "1", "b": "2", "c": "3"}, - { - "d": "4", - "e": {"e1": "5", "e2": "6"}, - "f": {"f1": [9, 10, 11]}, - }, - ], - ) - class TimestampTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" @@ -2108,14 +2089,14 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): return table - def test_reflection(self, special_types_table): + def test_reflection(self, special_types_table, connection): # cheat so that the "strict type check" # works special_types_table.c.year_interval.type = postgresql.INTERVAL() special_types_table.c.month_interval.type = postgresql.INTERVAL() m = MetaData() - t = Table("sometable", m, autoload_with=testing.db) + t = Table("sometable", m, autoload_with=connection) self.assert_tables_equal(special_types_table, t, strict_types=True) assert t.c.plain_interval.type.precision is None @@ -2210,7 +2191,7 @@ class UUIDTest(fixtures.TestBase): class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", @@ -2494,17 +2475,16 @@ class HStoreRoundTripTest(fixtures.TablesTest): Column("data", HSTORE), ) - def _fixture_data(self, engine): + def _fixture_data(self, connection): data_table = self.tables.data_table - with engine.begin() as conn: - conn.execute( - data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, - {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, - {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, - {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, - ) + connection.execute( + data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + {"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}}, + {"name": "r3", "data": {"k1": "r3v1", "k2": "r3v2"}}, + {"name": "r4", "data": {"k1": "r4v1", "k2": "r4v2"}}, + {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2"}}, + ) def _assert_data(self, compare, conn): data = conn.execute( @@ -2514,26 +2494,32 @@ class HStoreRoundTripTest(fixtures.TablesTest): ).fetchall() eq_([d for d, in data], compare) - def _test_insert(self, engine): - with engine.begin() as conn: - conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, - ) - self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn) + def _test_insert(self, connection): + connection.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, + ) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection) - def _non_native_engine(self): - if testing.requires.psycopg2_native_hstore.enabled: - engine = engines.testing_engine( - options=dict(use_native_hstore=False) - ) + @testing.fixture + def non_native_hstore_connection(self, testing_engine): + local_engine = testing.requires.psycopg2_native_hstore.enabled + + if local_engine: + engine = testing_engine(options=dict(use_native_hstore=False)) else: engine = testing.db - engine.connect().close() - return engine - def test_reflect(self): - insp = inspect(testing.db) + conn = engine.connect() + trans = conn.begin() + yield conn + try: + trans.rollback() + finally: + conn.close() + + def test_reflect(self, connection): + insp = inspect(connection) cols = insp.get_columns("data_table") assert isinstance(cols[2]["type"], HSTORE) @@ -2548,106 +2534,88 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_(connection.scalar(select(expr)), "3") @testing.requires.psycopg2_native_hstore - def test_insert_native(self): - engine = testing.db - self._test_insert(engine) + def test_insert_native(self, connection): + self._test_insert(connection) - def test_insert_python(self): - engine = self._non_native_engine() - self._test_insert(engine) + def test_insert_python(self, non_native_hstore_connection): + self._test_insert(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_criterion_native(self): - engine = testing.db - self._fixture_data(engine) - self._test_criterion(engine) + def test_criterion_native(self, connection): + self._fixture_data(connection) + self._test_criterion(connection) - def test_criterion_python(self): - engine = self._non_native_engine() - self._fixture_data(engine) - self._test_criterion(engine) + def test_criterion_python(self, non_native_hstore_connection): + self._fixture_data(non_native_hstore_connection) + self._test_criterion(non_native_hstore_connection) - def _test_criterion(self, engine): + def _test_criterion(self, connection): data_table = self.tables.data_table - with engine.begin() as conn: - result = conn.execute( - select(data_table.c.data).where( - data_table.c.data["k1"] == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + result = connection.execute( + select(data_table.c.data).where(data_table.c.data["k1"] == "r3v1") + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - def _test_fixed_round_trip(self, engine): - with engine.begin() as conn: - s = select( - hstore( - array(["key1", "key2", "key3"]), - array(["value1", "value2", "value3"]), - ) - ) - eq_( - conn.scalar(s), - {"key1": "value1", "key2": "value2", "key3": "value3"}, + def _test_fixed_round_trip(self, connection): + s = select( + hstore( + array(["key1", "key2", "key3"]), + array(["value1", "value2", "value3"]), ) + ) + eq_( + connection.scalar(s), + {"key1": "value1", "key2": "value2", "key3": "value3"}, + ) - def test_fixed_round_trip_python(self): - engine = self._non_native_engine() - self._test_fixed_round_trip(engine) + def test_fixed_round_trip_python(self, non_native_hstore_connection): + self._test_fixed_round_trip(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_fixed_round_trip_native(self): - engine = testing.db - self._test_fixed_round_trip(engine) + def test_fixed_round_trip_native(self, connection): + self._test_fixed_round_trip(connection) - def _test_unicode_round_trip(self, engine): - with engine.begin() as conn: - s = select( - hstore( - array( - [util.u("réveillé"), util.u("drôle"), util.u("S’il")] - ), - array( - [util.u("réveillé"), util.u("drôle"), util.u("S’il")] - ), - ) - ) - eq_( - conn.scalar(s), - { - util.u("réveillé"): util.u("réveillé"), - util.u("drôle"): util.u("drôle"), - util.u("S’il"): util.u("S’il"), - }, + def _test_unicode_round_trip(self, connection): + s = select( + hstore( + array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]), + array([util.u("réveillé"), util.u("drôle"), util.u("S’il")]), ) + ) + eq_( + connection.scalar(s), + { + util.u("réveillé"): util.u("réveillé"), + util.u("drôle"): util.u("drôle"), + util.u("S’il"): util.u("S’il"), + }, + ) @testing.requires.psycopg2_native_hstore - def test_unicode_round_trip_python(self): - engine = self._non_native_engine() - self._test_unicode_round_trip(engine) + def test_unicode_round_trip_python(self, non_native_hstore_connection): + self._test_unicode_round_trip(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_unicode_round_trip_native(self): - engine = testing.db - self._test_unicode_round_trip(engine) + def test_unicode_round_trip_native(self, connection): + self._test_unicode_round_trip(connection) - def test_escaped_quotes_round_trip_python(self): - engine = self._non_native_engine() - self._test_escaped_quotes_round_trip(engine) + def test_escaped_quotes_round_trip_python( + self, non_native_hstore_connection + ): + self._test_escaped_quotes_round_trip(non_native_hstore_connection) @testing.requires.psycopg2_native_hstore - def test_escaped_quotes_round_trip_native(self): - engine = testing.db - self._test_escaped_quotes_round_trip(engine) + def test_escaped_quotes_round_trip_native(self, connection): + self._test_escaped_quotes_round_trip(connection) - def _test_escaped_quotes_round_trip(self, engine): - with engine.begin() as conn: - conn.execute( - self.tables.data_table.insert(), - {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, - ) - self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], conn) + def _test_escaped_quotes_round_trip(self, connection): + connection.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": {r"key \"foo\"": r'value \"bar"\ xyz'}}, + ) + self._assert_data([{r"key \"foo\"": r'value \"bar"\ xyz'}], connection) - def test_orm_round_trip(self): + def test_orm_round_trip(self, connection): from sqlalchemy import orm class Data(object): @@ -2656,13 +2624,14 @@ class HStoreRoundTripTest(fixtures.TablesTest): self.data = data orm.mapper(Data, self.tables.data_table) - s = orm.Session(testing.db) - d = Data( - name="r1", - data={"key1": "value1", "key2": "value2", "key3": "value3"}, - ) - s.add(d) - eq_(s.query(Data.data, Data).all(), [(d.data, d)]) + + with orm.Session(connection) as s: + d = Data( + name="r1", + data={"key1": "value1", "key2": "value2", "key3": "value3"}, + ) + s.add(d) + eq_(s.query(Data.data, Data).all(), [(d.data, d)]) class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): @@ -2671,7 +2640,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): # operator tests @classmethod - def setup_class(cls): + def setup_test_class(cls): table = Table( "data_table", MetaData(), @@ -2852,10 +2821,10 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): def test_actual_type(self): eq_(str(self._col_type()), self._col_str) - def test_reflect(self): + def test_reflect(self, connection): from sqlalchemy import inspect - insp = inspect(testing.db) + insp = inspect(connection) cols = insp.get_columns("data_table") assert isinstance(cols[0]["type"], self._col_type) @@ -2986,8 +2955,8 @@ class _DateTimeTZRangeTests(object): def tstzs(self): if self._tstzs is None: - with testing.db.begin() as conn: - lower = conn.scalar(func.current_timestamp().select()) + with testing.db.connect() as connection: + lower = connection.scalar(func.current_timestamp().select()) upper = lower + datetime.timedelta(1) self._tstzs = (lower, upper) return self._tstzs @@ -3052,7 +3021,7 @@ class DateTimeTZRangeRoundTripTest(_DateTimeTZRangeTests, _RangeTypeRoundTrip): class JSONTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", @@ -3151,7 +3120,7 @@ class JSONRoundTripTest(fixtures.TablesTest): Column("nulldata", cls.data_type(none_as_null=True)), ) - def _fixture_data(self, engine): + def _fixture_data(self, connection): data_table = self.tables.data_table data = [ @@ -3162,8 +3131,7 @@ class JSONRoundTripTest(fixtures.TablesTest): {"name": "r5", "data": {"k1": "r5v1", "k2": "r5v2", "k3": 5}}, {"name": "r6", "data": {"k1": {"r6v1": {"subr": [1, 2, 3]}}}}, ] - with engine.begin() as conn: - conn.execute(data_table.insert(), data) + connection.execute(data_table.insert(), data) return data def _assert_data(self, compare, conn, column="data"): @@ -3185,51 +3153,39 @@ class JSONRoundTripTest(fixtures.TablesTest): ).fetchall() eq_([d for d, in data], [None]) - def _test_insert(self, conn): - conn.execute( + def test_reflect(self, connection): + insp = inspect(connection) + cols = insp.get_columns("data_table") + assert isinstance(cols[2]["type"], self.data_type) + + def test_insert(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}, ) - self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], conn) + self._assert_data([{"k1": "r1v1", "k2": "r1v2"}], connection) - def _test_insert_nulls(self, conn): - conn.execute( + def test_insert_nulls(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "data": null()} ) - self._assert_data([None], conn) + self._assert_data([None], connection) - def _test_insert_none_as_null(self, conn): - conn.execute( + def test_insert_none_as_null(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "nulldata": None}, ) - self._assert_column_is_NULL(conn, column="nulldata") + self._assert_column_is_NULL(connection, column="nulldata") - def _test_insert_nulljson_into_none_as_null(self, conn): - conn.execute( + def test_insert_nulljson_into_none_as_null(self, connection): + connection.execute( self.tables.data_table.insert(), {"name": "r1", "nulldata": JSON.NULL}, ) - self._assert_column_is_JSON_NULL(conn, column="nulldata") - - def test_reflect(self): - insp = inspect(testing.db) - cols = insp.get_columns("data_table") - assert isinstance(cols[2]["type"], self.data_type) - - def test_insert(self, connection): - self._test_insert(connection) - - def test_insert_nulls(self, connection): - self._test_insert_nulls(connection) + self._assert_column_is_JSON_NULL(connection, column="nulldata") - def test_insert_none_as_null(self, connection): - self._test_insert_none_as_null(connection) - - def test_insert_nulljson_into_none_as_null(self, connection): - self._test_insert_nulljson_into_none_as_null(connection) - - def test_custom_serialize_deserialize(self): + def test_custom_serialize_deserialize(self, testing_engine): import json def loads(value): @@ -3242,7 +3198,7 @@ class JSONRoundTripTest(fixtures.TablesTest): value["x"] = "dumps_y" return json.dumps(value) - engine = engines.testing_engine( + engine = testing_engine( options=dict(json_serializer=dumps, json_deserializer=loads) ) @@ -3250,14 +3206,26 @@ class JSONRoundTripTest(fixtures.TablesTest): with engine.begin() as conn: eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"}) - def test_criterion(self): - engine = testing.db - self._fixture_data(engine) - self._test_criterion(engine) + def test_criterion(self, connection): + self._fixture_data(connection) + data_table = self.tables.data_table + + result = connection.execute( + select(data_table.c.data).where( + data_table.c.data["k1"].astext == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) + + result = connection.execute( + select(data_table.c.data).where( + data_table.c.data["k1"].astext.cast(String) == "r3v1" + ) + ).first() + eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) def test_path_query(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( @@ -3271,8 +3239,7 @@ class JSONRoundTripTest(fixtures.TablesTest): "postgresql < 9.4", "Improvement in PostgreSQL behavior?" ) def test_multi_index_query(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( @@ -3283,20 +3250,18 @@ class JSONRoundTripTest(fixtures.TablesTest): eq_(result.scalar(), "r6") def test_query_returned_as_text(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( select(data_table.c.data["k1"].astext) ).first() - if engine.dialect.returns_unicode_strings: + if connection.dialect.returns_unicode_strings: assert isinstance(result[0], util.text_type) else: assert isinstance(result[0], util.string_types) def test_query_returned_as_int(self, connection): - engine = testing.db - self._fixture_data(engine) + self._fixture_data(connection) data_table = self.tables.data_table result = connection.execute( select(data_table.c.data["k3"].astext.cast(Integer)).where( @@ -3305,23 +3270,6 @@ class JSONRoundTripTest(fixtures.TablesTest): ).first() assert isinstance(result[0], int) - def _test_criterion(self, engine): - data_table = self.tables.data_table - with engine.begin() as conn: - result = conn.execute( - select(data_table.c.data).where( - data_table.c.data["k1"].astext == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - - result = conn.execute( - select(data_table.c.data).where( - data_table.c.data["k1"].astext.cast(String) == "r3v1" - ) - ).first() - eq_(result, ({"k1": "r3v1", "k2": "r3v2"},)) - def test_fixed_round_trip(self, connection): s = select( cast( @@ -3352,42 +3300,41 @@ class JSONRoundTripTest(fixtures.TablesTest): }, ) - def test_eval_none_flag_orm(self): + def test_eval_none_flag_orm(self, connection): Base = declarative_base() class Data(Base): __table__ = self.tables.data_table - s = Session(testing.db) + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() - d1 = Data(name="d1", data=None, nulldata=None) - s.add(d1) - s.commit() - - s.bulk_insert_mappings( - Data, [{"name": "d2", "data": None, "nulldata": None}] - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String), - cast(self.tables.data_table.c.nulldata, String), + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] ) - .filter(self.tables.data_table.c.name == "d1") - .first(), - ("null", None), - ) - eq_( - s.query( - cast(self.tables.data_table.c.data, String), - cast(self.tables.data_table.c.nulldata, String), + eq_( + s.query( + cast(self.tables.data_table.c.data, String), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), ) - .filter(self.tables.data_table.c.name == "d2") - .first(), - ("null", None), - ) def test_literal(self, connection): - exp = self._fixture_data(testing.db) + exp = self._fixture_data(connection) result = connection.exec_driver_sql( "select data from data_table order by name" ) @@ -3395,11 +3342,10 @@ class JSONRoundTripTest(fixtures.TablesTest): eq_(len(res), len(exp)) for row, expected in zip(res, exp): eq_(row[0], expected["data"]) - result.close() class JSONBTest(JSONTest): - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 4658b40a8d..1926c60652 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -37,6 +37,7 @@ from sqlalchemy import UniqueConstraint from sqlalchemy import util from sqlalchemy.dialects.sqlite import base as sqlite from sqlalchemy.dialects.sqlite import insert +from sqlalchemy.dialects.sqlite import provision from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect from sqlalchemy.engine.url import make_url from sqlalchemy.schema import CreateTable @@ -46,6 +47,7 @@ from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import combinations +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -774,7 +776,7 @@ class AttachedDBTest(fixtures.TestBase): def _fixture(self): meta = self.metadata - self.conn = testing.db.connect() + self.conn = self.engine.connect() Table("created", meta, Column("foo", Integer), Column("bar", String)) Table("local_only", meta, Column("q", Integer), Column("p", Integer)) @@ -798,14 +800,20 @@ class AttachedDBTest(fixtures.TestBase): meta.create_all(self.conn) return ct - def setup(self): - self.conn = testing.db.connect() + def setup_test(self): + self.engine = engines.testing_engine(options={"use_reaper": False}) + + provision._sqlite_post_configure_engine( + self.engine.url, self.engine, config.ident + ) + self.conn = self.engine.connect() self.metadata = MetaData() - def teardown(self): + def teardown_test(self): with self.conn.begin(): self.metadata.drop_all(self.conn) self.conn.close() + self.engine.dispose() def test_no_tables(self): insp = inspect(self.conn) @@ -1495,7 +1503,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): __skip_if__ = (full_text_search_missing,) @classmethod - def setup_class(cls): + def setup_test_class(cls): global metadata, cattable, matchtable metadata = MetaData() exec_sql( @@ -1559,7 +1567,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): metadata.drop_all(testing.db) def test_expression(self): @@ -1681,7 +1689,7 @@ class AutoIncrementTest(fixtures.TestBase, AssertsCompiledSQL): class ReflectHeadlessFKsTest(fixtures.TestBase): __only_on__ = "sqlite" - def setup(self): + def setup_test(self): exec_sql(testing.db, "CREATE TABLE a (id INTEGER PRIMARY KEY)") # this syntax actually works on other DBs perhaps we'd want to add # tests to test_reflection @@ -1689,7 +1697,7 @@ class ReflectHeadlessFKsTest(fixtures.TestBase): testing.db, "CREATE TABLE b (id INTEGER PRIMARY KEY REFERENCES a)" ) - def teardown(self): + def teardown_test(self): exec_sql(testing.db, "drop table b") exec_sql(testing.db, "drop table a") @@ -1728,7 +1736,7 @@ class ConstraintReflectionTest(fixtures.TestBase): __only_on__ = "sqlite" @classmethod - def setup_class(cls): + def setup_test_class(cls): with testing.db.begin() as conn: conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)") @@ -1876,7 +1884,7 @@ class ConstraintReflectionTest(fixtures.TestBase): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: for name in [ "implicit_referrer_comp_fake", @@ -2370,7 +2378,7 @@ class SavepointTest(fixtures.TablesTest): @classmethod def setup_bind(cls): - engine = engines.testing_engine(options={"use_reaper": False}) + engine = engines.testing_engine(options={"scope": "class"}) @event.listens_for(engine, "connect") def do_connect(dbapi_connection, connection_record): @@ -2579,7 +2587,7 @@ class TypeReflectionTest(fixtures.TestBase): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "sqlite" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index 396b48aa4a..baa766d48f 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -21,7 +21,7 @@ from sqlalchemy.testing.schema import Table class DDLEventTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.bind = engines.mock_engine() self.metadata = MetaData() self.table = Table("t", self.metadata, Column("id", Integer)) @@ -374,7 +374,7 @@ class DDLEventTest(fixtures.TestBase): class DDLExecutionTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.engine = engines.mock_engine() self.metadata = MetaData() self.users = Table( diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index a18cf756b1..0a2c9abe58 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -965,7 +965,7 @@ class TransactionTest(fixtures.TablesTest): class HandleInvalidatedOnConnectTest(fixtures.TestBase): __requires__ = ("sqlite",) - def setUp(self): + def setup_test(self): e = create_engine("sqlite://") connection = Mock(get_server_version_info=Mock(return_value="5.0")) @@ -1021,18 +1021,18 @@ def MockDBAPI(): # noqa class PoolTestBase(fixtures.TestBase): - def setup(self): + def setup_test(self): pool.clear_managers() self._teardown_conns = [] - def teardown(self): + def teardown_test(self): for ref in self._teardown_conns: conn = ref() if conn: conn.close() @classmethod - def teardown_class(cls): + def teardown_test_class(cls): pool.clear_managers() def _queuepool_fixture(self, **kw): @@ -1597,7 +1597,7 @@ class EngineEventsTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) __backend__ = True - def tearDown(self): + def teardown_test(self): Engine.dispatch._clear() Engine._has_events = False @@ -1650,6 +1650,7 @@ class EngineEventsTest(fixtures.TestBase): event.listen( engine, "before_cursor_execute", cursor_execute, retval=True ) + with testing.expect_deprecated( r"The argument signature for the " r"\"ConnectionEvents.before_execute\" event listener", @@ -1676,11 +1677,12 @@ class EngineEventsTest(fixtures.TestBase): r"The argument signature for the " r"\"ConnectionEvents.after_execute\" event listener", ): - e1.execute(select(1)) + result = e1.execute(select(1)) + result.close() class DDLExecutionTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.engine = engines.mock_engine() self.metadata = MetaData() self.users = Table( diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 21d4e06e06..a1e4ea218e 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -43,7 +43,6 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import CompiledSQL -from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.mock import call from sqlalchemy.testing.mock import Mock from sqlalchemy.testing.mock import patch @@ -94,13 +93,13 @@ class ExecuteTest(fixtures.TablesTest): ).default_from() ) - conn = testing.db.connect() - result = ( - conn.execution_options(no_parameters=True) - .exec_driver_sql(stmt) - .scalar() - ) - eq_(result, "%") + with testing.db.connect() as conn: + result = ( + conn.execution_options(no_parameters=True) + .exec_driver_sql(stmt) + .scalar() + ) + eq_(result, "%") def test_raw_positional_invalid(self, connection): assert_raises_message( @@ -261,16 +260,15 @@ class ExecuteTest(fixtures.TablesTest): (4, "sally"), ] - @testing.engines.close_open_connections def test_exception_wrapping_dbapi(self): - conn = testing.db.connect() - # engine does not have exec_driver_sql - assert_raises_message( - tsa.exc.DBAPIError, - r"not_a_valid_statement", - conn.exec_driver_sql, - "not_a_valid_statement", - ) + with testing.db.connect() as conn: + # engine does not have exec_driver_sql + assert_raises_message( + tsa.exc.DBAPIError, + r"not_a_valid_statement", + conn.exec_driver_sql, + "not_a_valid_statement", + ) @testing.requires.sqlite def test_exception_wrapping_non_dbapi_error(self): @@ -864,12 +862,10 @@ class CompiledCacheTest(fixtures.TestBase): ["sqlite", "mysql", "postgresql"], "uses blob value that is problematic for some DBAPIs", ) - @testing.provide_metadata - def test_cache_noleak_on_statement_values(self, connection): + def test_cache_noleak_on_statement_values(self, metadata, connection): # This is a non regression test for an object reference leak caused # by the compiled_cache. - metadata = self.metadata photo = Table( "photo", metadata, @@ -1040,7 +1036,19 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): __requires__ = ("schemas",) __backend__ = True - def test_create_table(self): + @testing.fixture + def plain_tables(self, metadata): + t1 = Table( + "t1", metadata, Column("x", Integer), schema=config.test_schema + ) + t2 = Table( + "t2", metadata, Column("x", Integer), schema=config.test_schema + ) + t3 = Table("t3", metadata, Column("x", Integer), schema=None) + + return t1, t2, t3 + + def test_create_table(self, plain_tables, connection): map_ = { None: config.test_schema, "foo": config.test_schema, @@ -1052,18 +1060,16 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t2 = Table("t2", metadata, Column("x", Integer), schema="foo") t3 = Table("t3", metadata, Column("x", Integer), schema="bar") - with self.sql_execution_asserter(config.db) as asserter: - with config.db.begin() as conn, conn.execution_options( - schema_translate_map=map_ - ) as conn: + with self.sql_execution_asserter(connection) as asserter: + conn = connection.execution_options(schema_translate_map=map_) - t1.create(conn) - t2.create(conn) - t3.create(conn) + t1.create(conn) + t2.create(conn) + t3.create(conn) - t3.drop(conn) - t2.drop(conn) - t1.drop(conn) + t3.drop(conn) + t2.drop(conn) + t1.drop(conn) asserter.assert_( CompiledSQL("CREATE TABLE [SCHEMA__none].t1 (x INTEGER)"), @@ -1074,14 +1080,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): CompiledSQL("DROP TABLE [SCHEMA__none].t1"), ) - def _fixture(self): - metadata = self.metadata - Table("t1", metadata, Column("x", Integer), schema=config.test_schema) - Table("t2", metadata, Column("x", Integer), schema=config.test_schema) - Table("t3", metadata, Column("x", Integer), schema=None) - metadata.create_all(testing.db) - - def test_ddl_hastable(self): + def test_ddl_hastable(self, plain_tables, connection): map_ = { None: config.test_schema, @@ -1094,27 +1093,28 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): Table("t2", metadata, Column("x", Integer), schema="foo") Table("t3", metadata, Column("x", Integer), schema="bar") - with config.db.begin() as conn: - conn = conn.execution_options(schema_translate_map=map_) - metadata.create_all(conn) + conn = connection.execution_options(schema_translate_map=map_) + metadata.create_all(conn) - insp = inspect(config.db) + insp = inspect(connection) is_true(insp.has_table("t1", schema=config.test_schema)) is_true(insp.has_table("t2", schema=config.test_schema)) is_true(insp.has_table("t3", schema=None)) - with config.db.begin() as conn: - conn = conn.execution_options(schema_translate_map=map_) - metadata.drop_all(conn) + conn = connection.execution_options(schema_translate_map=map_) + + # if this test fails, the tables won't get dropped. so need a + # more robust fixture for this + metadata.drop_all(conn) - insp = inspect(config.db) + insp = inspect(connection) is_false(insp.has_table("t1", schema=config.test_schema)) is_false(insp.has_table("t2", schema=config.test_schema)) is_false(insp.has_table("t3", schema=None)) - @testing.provide_metadata - def test_option_on_execute(self): - self._fixture() + def test_option_on_execute(self, plain_tables, connection): + # provided by metadata fixture provided by plain_tables fixture + self.metadata.create_all(connection) map_ = { None: config.test_schema, @@ -1127,61 +1127,54 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t2 = Table("t2", metadata, Column("x", Integer), schema="foo") t3 = Table("t3", metadata, Column("x", Integer), schema="bar") - with self.sql_execution_asserter(config.db) as asserter: - with config.db.begin() as conn: + with self.sql_execution_asserter(connection) as asserter: + conn = connection + execution_options = {"schema_translate_map": map_} + conn._execute_20( + t1.insert(), {"x": 1}, execution_options=execution_options + ) + conn._execute_20( + t2.insert(), {"x": 1}, execution_options=execution_options + ) + conn._execute_20( + t3.insert(), {"x": 1}, execution_options=execution_options + ) - execution_options = {"schema_translate_map": map_} - conn._execute_20( - t1.insert(), {"x": 1}, execution_options=execution_options - ) - conn._execute_20( - t2.insert(), {"x": 1}, execution_options=execution_options - ) - conn._execute_20( - t3.insert(), {"x": 1}, execution_options=execution_options - ) + conn._execute_20( + t1.update().values(x=1).where(t1.c.x == 1), + execution_options=execution_options, + ) + conn._execute_20( + t2.update().values(x=2).where(t2.c.x == 1), + execution_options=execution_options, + ) + conn._execute_20( + t3.update().values(x=3).where(t3.c.x == 1), + execution_options=execution_options, + ) + eq_( conn._execute_20( - t1.update().values(x=1).where(t1.c.x == 1), - execution_options=execution_options, - ) + select(t1.c.x), execution_options=execution_options + ).scalar(), + 1, + ) + eq_( conn._execute_20( - t2.update().values(x=2).where(t2.c.x == 1), - execution_options=execution_options, - ) + select(t2.c.x), execution_options=execution_options + ).scalar(), + 2, + ) + eq_( conn._execute_20( - t3.update().values(x=3).where(t3.c.x == 1), - execution_options=execution_options, - ) - - eq_( - conn._execute_20( - select(t1.c.x), execution_options=execution_options - ).scalar(), - 1, - ) - eq_( - conn._execute_20( - select(t2.c.x), execution_options=execution_options - ).scalar(), - 2, - ) - eq_( - conn._execute_20( - select(t3.c.x), execution_options=execution_options - ).scalar(), - 3, - ) + select(t3.c.x), execution_options=execution_options + ).scalar(), + 3, + ) - conn._execute_20( - t1.delete(), execution_options=execution_options - ) - conn._execute_20( - t2.delete(), execution_options=execution_options - ) - conn._execute_20( - t3.delete(), execution_options=execution_options - ) + conn._execute_20(t1.delete(), execution_options=execution_options) + conn._execute_20(t2.delete(), execution_options=execution_options) + conn._execute_20(t3.delete(), execution_options=execution_options) asserter.assert_( CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"), @@ -1207,9 +1200,9 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): CompiledSQL("DELETE FROM [SCHEMA_bar].t3"), ) - @testing.provide_metadata - def test_crud(self): - self._fixture() + def test_crud(self, plain_tables, connection): + # provided by metadata fixture provided by plain_tables fixture + self.metadata.create_all(connection) map_ = { None: config.test_schema, @@ -1222,26 +1215,24 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t2 = Table("t2", metadata, Column("x", Integer), schema="foo") t3 = Table("t3", metadata, Column("x", Integer), schema="bar") - with self.sql_execution_asserter(config.db) as asserter: - with config.db.begin() as conn, conn.execution_options( - schema_translate_map=map_ - ) as conn: + with self.sql_execution_asserter(connection) as asserter: + conn = connection.execution_options(schema_translate_map=map_) - conn.execute(t1.insert(), {"x": 1}) - conn.execute(t2.insert(), {"x": 1}) - conn.execute(t3.insert(), {"x": 1}) + conn.execute(t1.insert(), {"x": 1}) + conn.execute(t2.insert(), {"x": 1}) + conn.execute(t3.insert(), {"x": 1}) - conn.execute(t1.update().values(x=1).where(t1.c.x == 1)) - conn.execute(t2.update().values(x=2).where(t2.c.x == 1)) - conn.execute(t3.update().values(x=3).where(t3.c.x == 1)) + conn.execute(t1.update().values(x=1).where(t1.c.x == 1)) + conn.execute(t2.update().values(x=2).where(t2.c.x == 1)) + conn.execute(t3.update().values(x=3).where(t3.c.x == 1)) - eq_(conn.scalar(select(t1.c.x)), 1) - eq_(conn.scalar(select(t2.c.x)), 2) - eq_(conn.scalar(select(t3.c.x)), 3) + eq_(conn.scalar(select(t1.c.x)), 1) + eq_(conn.scalar(select(t2.c.x)), 2) + eq_(conn.scalar(select(t3.c.x)), 3) - conn.execute(t1.delete()) - conn.execute(t2.delete()) - conn.execute(t3.delete()) + conn.execute(t1.delete()) + conn.execute(t2.delete()) + conn.execute(t3.delete()) asserter.assert_( CompiledSQL("INSERT INTO [SCHEMA__none].t1 (x) VALUES (:x)"), @@ -1267,9 +1258,10 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): CompiledSQL("DELETE FROM [SCHEMA_bar].t3"), ) - @testing.provide_metadata - def test_via_engine(self): - self._fixture() + def test_via_engine(self, plain_tables, metadata): + + with config.db.begin() as connection: + metadata.create_all(connection) map_ = { None: config.test_schema, @@ -1282,25 +1274,25 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): with self.sql_execution_asserter(config.db) as asserter: eng = config.db.execution_options(schema_translate_map=map_) - conn = eng.connect() - conn.execute(select(t2.c.x)) + with eng.connect() as conn: + conn.execute(select(t2.c.x)) asserter.assert_( CompiledSQL("SELECT [SCHEMA_foo].t2.x FROM [SCHEMA_foo].t2") ) class ExecutionOptionsTest(fixtures.TestBase): - def test_dialect_conn_options(self): + def test_dialect_conn_options(self, testing_engine): engine = testing_engine("sqlite://", options=dict(_initialize=False)) engine.dialect = Mock() - conn = engine.connect() - c2 = conn.execution_options(foo="bar") - eq_( - engine.dialect.set_connection_execution_options.mock_calls, - [call(c2, {"foo": "bar"})], - ) + with engine.connect() as conn: + c2 = conn.execution_options(foo="bar") + eq_( + engine.dialect.set_connection_execution_options.mock_calls, + [call(c2, {"foo": "bar"})], + ) - def test_dialect_engine_options(self): + def test_dialect_engine_options(self, testing_engine): engine = testing_engine("sqlite://") engine.dialect = Mock() e2 = engine.execution_options(foo="bar") @@ -1319,14 +1311,14 @@ class ExecutionOptionsTest(fixtures.TestBase): [call(engine, {"foo": "bar"})], ) - def test_propagate_engine_to_connection(self): + def test_propagate_engine_to_connection(self, testing_engine): engine = testing_engine( "sqlite://", options=dict(execution_options={"foo": "bar"}) ) - conn = engine.connect() - eq_(conn._execution_options, {"foo": "bar"}) + with engine.connect() as conn: + eq_(conn._execution_options, {"foo": "bar"}) - def test_propagate_option_engine_to_connection(self): + def test_propagate_option_engine_to_connection(self, testing_engine): e1 = testing_engine( "sqlite://", options=dict(execution_options={"foo": "bar"}) ) @@ -1336,27 +1328,30 @@ class ExecutionOptionsTest(fixtures.TestBase): eq_(c1._execution_options, {"foo": "bar"}) eq_(c2._execution_options, {"foo": "bar", "bat": "hoho"}) - def test_get_engine_execution_options(self): + c1.close() + c2.close() + + def test_get_engine_execution_options(self, testing_engine): engine = testing_engine("sqlite://") engine.dialect = Mock() e2 = engine.execution_options(foo="bar") eq_(e2.get_execution_options(), {"foo": "bar"}) - def test_get_connection_execution_options(self): + def test_get_connection_execution_options(self, testing_engine): engine = testing_engine("sqlite://", options=dict(_initialize=False)) engine.dialect = Mock() - conn = engine.connect() - c = conn.execution_options(foo="bar") + with engine.connect() as conn: + c = conn.execution_options(foo="bar") - eq_(c.get_execution_options(), {"foo": "bar"}) + eq_(c.get_execution_options(), {"foo": "bar"}) class EngineEventsTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) __backend__ = True - def tearDown(self): + def teardown_test(self): Engine.dispatch._clear() Engine._has_events = False @@ -1376,7 +1371,7 @@ class EngineEventsTest(fixtures.TestBase): ): break - def test_per_engine_independence(self): + def test_per_engine_independence(self, testing_engine): e1 = testing_engine(config.db_url) e2 = testing_engine(config.db_url) @@ -1400,7 +1395,7 @@ class EngineEventsTest(fixtures.TestBase): conn.execute(s2) eq_([arg[1][1] for arg in canary.mock_calls], [s1, s1, s2]) - def test_per_engine_plus_global(self): + def test_per_engine_plus_global(self, testing_engine): canary = Mock() event.listen(Engine, "before_execute", canary.be1) e1 = testing_engine(config.db_url) @@ -1409,8 +1404,6 @@ class EngineEventsTest(fixtures.TestBase): event.listen(e1, "before_execute", canary.be2) event.listen(Engine, "before_execute", canary.be3) - e1.connect() - e2.connect() with e1.connect() as conn: conn.execute(select(1)) @@ -1424,7 +1417,7 @@ class EngineEventsTest(fixtures.TestBase): eq_(canary.be2.call_count, 1) eq_(canary.be3.call_count, 2) - def test_per_connection_plus_engine(self): + def test_per_connection_plus_engine(self, testing_engine): canary = Mock() e1 = testing_engine(config.db_url) @@ -1442,9 +1435,14 @@ class EngineEventsTest(fixtures.TestBase): eq_(canary.be1.call_count, 2) eq_(canary.be2.call_count, 2) - @testing.combinations((True, False), (True, True), (False, False)) + @testing.combinations( + (True, False), + (True, True), + (False, False), + argnames="mock_out_on_connect, add_our_own_onconnect", + ) def test_insert_connect_is_definitely_first( - self, mock_out_on_connect, add_our_own_onconnect + self, mock_out_on_connect, add_our_own_onconnect, testing_engine ): """test issue #5708. @@ -1478,7 +1476,7 @@ class EngineEventsTest(fixtures.TestBase): patcher = util.nullcontext() with patcher: - e1 = create_engine(config.db_url) + e1 = testing_engine(config.db_url) initialize = e1.dialect.initialize @@ -1559,10 +1557,11 @@ class EngineEventsTest(fixtures.TestBase): conn.exec_driver_sql(select1(testing.db)) eq_(m1.mock_calls, []) - def test_add_event_after_connect(self): + def test_add_event_after_connect(self, testing_engine): # new feature as of #2978 + canary = Mock() - e1 = create_engine(config.db_url) + e1 = testing_engine(config.db_url, future=False) assert not e1._has_events conn = e1.connect() @@ -1575,9 +1574,9 @@ class EngineEventsTest(fixtures.TestBase): conn._branch().execute(select(1)) eq_(canary.be1.call_count, 2) - def test_force_conn_events_false(self): + def test_force_conn_events_false(self, testing_engine): canary = Mock() - e1 = create_engine(config.db_url) + e1 = testing_engine(config.db_url, future=False) assert not e1._has_events event.listen(e1, "before_execute", canary.be1) @@ -1593,7 +1592,7 @@ class EngineEventsTest(fixtures.TestBase): conn._branch().execute(select(1)) eq_(canary.be1.call_count, 0) - def test_cursor_events_ctx_execute_scalar(self): + def test_cursor_events_ctx_execute_scalar(self, testing_engine): canary = Mock() e1 = testing_engine(config.db_url) @@ -1620,7 +1619,7 @@ class EngineEventsTest(fixtures.TestBase): [call(conn, ctx.cursor, stmt, ctx.parameters[0], ctx, False)], ) - def test_cursor_events_execute(self): + def test_cursor_events_execute(self, testing_engine): canary = Mock() e1 = testing_engine(config.db_url) @@ -1653,9 +1652,15 @@ class EngineEventsTest(fixtures.TestBase): ), ((), {"z": 10}, [], {"z": 10}, testing.requires.legacy_engine), (({"z": 10},), {}, [], {"z": 10}), + argnames="multiparams, params, expected_multiparams, expected_params", ) def test_modify_parameters_from_event_one( - self, multiparams, params, expected_multiparams, expected_params + self, + multiparams, + params, + expected_multiparams, + expected_params, + testing_engine, ): # this is testing both the normalization added to parameters # as of I97cb4d06adfcc6b889f10d01cc7775925cffb116 as well as @@ -1704,7 +1709,9 @@ class EngineEventsTest(fixtures.TestBase): [(15,), (19,)], ) - def test_modify_parameters_from_event_three(self, connection): + def test_modify_parameters_from_event_three( + self, connection, testing_engine + ): def before_execute( conn, clauseelement, multiparams, params, execution_options ): @@ -1721,7 +1728,7 @@ class EngineEventsTest(fixtures.TestBase): with e1.connect() as conn: conn.execute(select(literal("1"))) - def test_argument_format_execute(self): + def test_argument_format_execute(self, testing_engine): def before_execute( conn, clauseelement, multiparams, params, execution_options ): @@ -1956,9 +1963,9 @@ class EngineEventsTest(fixtures.TestBase): ) @testing.requires.ad_hoc_engines - def test_dispose_event(self): + def test_dispose_event(self, testing_engine): canary = Mock() - eng = create_engine(testing.db.url) + eng = testing_engine(testing.db.url) event.listen(eng, "engine_disposed", canary) conn = eng.connect() @@ -2102,13 +2109,13 @@ class EngineEventsTest(fixtures.TestBase): event.listen(engine, "commit", tracker("commit")) event.listen(engine, "rollback", tracker("rollback")) - conn = engine.connect() - trans = conn.begin() - conn.execute(select(1)) - trans.rollback() - trans = conn.begin() - conn.execute(select(1)) - trans.commit() + with engine.connect() as conn: + trans = conn.begin() + conn.execute(select(1)) + trans.rollback() + trans = conn.begin() + conn.execute(select(1)) + trans.commit() eq_( canary, @@ -2145,13 +2152,13 @@ class EngineEventsTest(fixtures.TestBase): event.listen(engine, "commit", tracker("commit"), named=True) event.listen(engine, "rollback", tracker("rollback"), named=True) - conn = engine.connect() - trans = conn.begin() - conn.execute(select(1)) - trans.rollback() - trans = conn.begin() - conn.execute(select(1)) - trans.commit() + with engine.connect() as conn: + trans = conn.begin() + conn.execute(select(1)) + trans.rollback() + trans = conn.begin() + conn.execute(select(1)) + trans.commit() eq_( canary, @@ -2310,7 +2317,7 @@ class HandleErrorTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) __backend__ = True - def tearDown(self): + def teardown_test(self): Engine.dispatch._clear() Engine._has_events = False @@ -2742,7 +2749,7 @@ class HandleErrorTest(fixtures.TestBase): class HandleInvalidatedOnConnectTest(fixtures.TestBase): __requires__ = ("sqlite",) - def setUp(self): + def setup_test(self): e = create_engine("sqlite://") connection = Mock(get_server_version_info=Mock(return_value="5.0")) @@ -3014,6 +3021,9 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase): ], ) + c.close() + c2.close() + class DialectEventTest(fixtures.TestBase): @contextmanager @@ -3370,7 +3380,7 @@ class SetInputSizesTest(fixtures.TablesTest): ) @testing.fixture - def input_sizes_fixture(self): + def input_sizes_fixture(self, testing_engine): canary = mock.Mock() def do_set_input_sizes(cursor, list_of_tuples, context): diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index 29b8132aa3..c565892487 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -30,7 +30,7 @@ class LogParamsTest(fixtures.TestBase): __only_on__ = "sqlite" __requires__ = ("ad_hoc_engines",) - def setup(self): + def setup_test(self): self.eng = engines.testing_engine(options={"echo": True}) self.no_param_engine = engines.testing_engine( options={"echo": True, "hide_parameters": True} @@ -44,7 +44,7 @@ class LogParamsTest(fixtures.TestBase): for log in [logging.getLogger("sqlalchemy.engine")]: log.addHandler(self.buf) - def teardown(self): + def teardown_test(self): exec_sql(self.eng, "drop table if exists foo") for log in [logging.getLogger("sqlalchemy.engine")]: log.removeHandler(self.buf) @@ -413,14 +413,14 @@ class LogParamsTest(fixtures.TestBase): class PoolLoggingTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.existing_level = logging.getLogger("sqlalchemy.pool").level self.buf = logging.handlers.BufferingHandler(100) for log in [logging.getLogger("sqlalchemy.pool")]: log.addHandler(self.buf) - def teardown(self): + def teardown_test(self): for log in [logging.getLogger("sqlalchemy.pool")]: log.removeHandler(self.buf) logging.getLogger("sqlalchemy.pool").setLevel(self.existing_level) @@ -528,7 +528,7 @@ class LoggingNameTest(fixtures.TestBase): kw.update({"echo": True}) return engines.testing_engine(options=kw) - def setup(self): + def setup_test(self): self.buf = logging.handlers.BufferingHandler(100) for log in [ logging.getLogger("sqlalchemy.engine"), @@ -536,7 +536,7 @@ class LoggingNameTest(fixtures.TestBase): ]: log.addHandler(self.buf) - def teardown(self): + def teardown_test(self): for log in [ logging.getLogger("sqlalchemy.engine"), logging.getLogger("sqlalchemy.pool"), @@ -588,13 +588,13 @@ class LoggingNameTest(fixtures.TestBase): class EchoTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) - def setup(self): + def setup_test(self): self.level = logging.getLogger("sqlalchemy.engine").level logging.getLogger("sqlalchemy.engine").setLevel(logging.WARN) self.buf = logging.handlers.BufferingHandler(100) logging.getLogger("sqlalchemy.engine").addHandler(self.buf) - def teardown(self): + def teardown_test(self): logging.getLogger("sqlalchemy.engine").removeHandler(self.buf) logging.getLogger("sqlalchemy.engine").setLevel(self.level) diff --git a/test/engine/test_pool.py b/test/engine/test_pool.py index 550fedb8e6..decdce3f9b 100644 --- a/test/engine/test_pool.py +++ b/test/engine/test_pool.py @@ -17,7 +17,9 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_none from sqlalchemy.testing import is_not +from sqlalchemy.testing import is_not_none from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.engines import testing_engine @@ -63,18 +65,18 @@ def MockDBAPI(): # noqa class PoolTestBase(fixtures.TestBase): - def setup(self): + def setup_test(self): pool.clear_managers() self._teardown_conns = [] - def teardown(self): + def teardown_test(self): for ref in self._teardown_conns: conn = ref() if conn: conn.close() @classmethod - def teardown_class(cls): + def teardown_test_class(cls): pool.clear_managers() def _with_teardown(self, connection): @@ -364,10 +366,17 @@ class PoolEventsTest(PoolTestBase): p = self._queuepool_fixture() canary = [] + @event.listens_for(p, "checkin") def checkin(*arg, **kw): canary.append("checkin") - event.listen(p, "checkin", checkin) + @event.listens_for(p, "close_detached") + def close_detached(*arg, **kw): + canary.append("close_detached") + + @event.listens_for(p, "detach") + def detach(*arg, **kw): + canary.append("detach") return p, canary @@ -629,15 +638,35 @@ class PoolEventsTest(PoolTestBase): assert canary.call_args_list[0][0][0] is dbapi_con assert canary.call_args_list[0][0][2] is exc + @testing.combinations((True, testing.requires.python3), (False,)) @testing.requires.predictable_gc - def test_checkin_event_gc(self): + def test_checkin_event_gc(self, detach_gced): p, canary = self._checkin_event_fixture() + if detach_gced: + p._is_asyncio = True + c1 = p.connect() + + dbapi_connection = weakref.ref(c1.connection) + eq_(canary, []) del c1 lazy_gc() - eq_(canary, ["checkin"]) + + if detach_gced: + # "close_detached" is not called because for asyncio the + # connection is just lost. + eq_(canary, ["detach"]) + + else: + eq_(canary, ["checkin"]) + + gc_collect() + if detach_gced: + is_none(dbapi_connection()) + else: + is_not_none(dbapi_connection()) def test_checkin_event_on_subsequently_recreated(self): p, canary = self._checkin_event_fixture() @@ -744,7 +773,7 @@ class PoolEventsTest(PoolTestBase): eq_(conn.info["important_flag"], True) conn.close() - def teardown(self): + def teardown_test(self): # TODO: need to get remove() functionality # going pool.Pool.dispatch._clear() @@ -1490,12 +1519,16 @@ class QueuePoolTest(PoolTestBase): self._assert_cleanup_on_pooled_reconnect(dbapi, p) + @testing.combinations((True, testing.requires.python3), (False,)) @testing.requires.predictable_gc - def test_userspace_disconnectionerror_weakref_finalizer(self): + def test_userspace_disconnectionerror_weakref_finalizer(self, detach_gced): dbapi, pool = self._queuepool_dbapi_fixture( pool_size=1, max_overflow=2 ) + if detach_gced: + pool._is_asyncio = True + @event.listens_for(pool, "checkout") def handle_checkout_event(dbapi_con, con_record, con_proxy): if getattr(dbapi_con, "boom") == "yes": @@ -1514,8 +1547,12 @@ class QueuePoolTest(PoolTestBase): del conn gc_collect() - # new connection was reset on return appropriately - eq_(dbapi_conn.mock_calls, [call.rollback()]) + if detach_gced: + # new connection was detached + abandoned on return + eq_(dbapi_conn.mock_calls, []) + else: + # new connection reset and returned to pool + eq_(dbapi_conn.mock_calls, [call.rollback()]) # old connection was just closed - did not get an # erroneous reset on return diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py index 3810de06a5..5a4220c827 100644 --- a/test/engine/test_processors.py +++ b/test/engine/test_processors.py @@ -25,7 +25,7 @@ class CBooleanProcessorTest(_BooleanProcessorTest): __requires__ = ("cextensions",) @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import cprocessors cls.module = cprocessors @@ -83,7 +83,7 @@ class _DateProcessorTest(fixtures.TestBase): class PyDateProcessorTest(_DateProcessorTest): @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import processors cls.module = type( @@ -100,7 +100,7 @@ class CDateProcessorTest(_DateProcessorTest): __requires__ = ("cextensions",) @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import cprocessors cls.module = cprocessors @@ -185,7 +185,7 @@ class _DistillArgsTest(fixtures.TestBase): class PyDistillArgsTest(_DistillArgsTest): @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy.engine import util cls.module = type( @@ -202,7 +202,7 @@ class CDistillArgsTest(_DistillArgsTest): __requires__ = ("cextensions",) @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import cutils as util cls.module = util diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 5fe7f6cc2a..7a64b25508 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -162,7 +162,7 @@ def MockDBAPI(): class PrePingMockTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.dbapi = MockDBAPI() def _pool_fixture(self, pre_ping, pool_kw=None): @@ -182,7 +182,7 @@ class PrePingMockTest(fixtures.TestBase): ) return _pool - def teardown(self): + def teardown_test(self): self.dbapi.dispose() def test_ping_not_on_first_connect(self): @@ -357,7 +357,7 @@ class PrePingMockTest(fixtures.TestBase): class MockReconnectTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.dbapi = MockDBAPI() self.db = testing_engine( @@ -373,7 +373,7 @@ class MockReconnectTest(fixtures.TestBase): e, MockDisconnect ) - def teardown(self): + def teardown_test(self): self.dbapi.dispose() def test_reconnect(self): @@ -1004,10 +1004,10 @@ class RealReconnectTest(fixtures.TestBase): __backend__ = True __requires__ = "graceful_disconnects", "ad_hoc_engines" - def setup(self): + def setup_test(self): self.engine = engines.reconnecting_engine() - def teardown(self): + def teardown_test(self): self.engine.dispose() def test_reconnect(self): @@ -1336,7 +1336,7 @@ class PrePingRealTest(fixtures.TestBase): class InvalidateDuringResultTest(fixtures.TestBase): __backend__ = True - def setup(self): + def setup_test(self): self.engine = engines.reconnecting_engine() self.meta = MetaData() table = Table( @@ -1353,7 +1353,7 @@ class InvalidateDuringResultTest(fixtures.TestBase): [{"id": i, "name": "row %d" % i} for i in range(1, 100)], ) - def teardown(self): + def teardown_test(self): with self.engine.begin() as conn: self.meta.drop_all(conn) self.engine.dispose() @@ -1470,7 +1470,7 @@ class ReconnectRecipeTest(fixtures.TestBase): __backend__ = True - def setup(self): + def setup_test(self): self.engine = engines.reconnecting_engine( options=dict(future=self.future) ) @@ -1483,7 +1483,7 @@ class ReconnectRecipeTest(fixtures.TestBase): ) self.meta.create_all(self.engine) - def teardown(self): + def teardown_test(self): self.meta.drop_all(self.engine) self.engine.dispose() diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 658cdd79f0..0a46ddeeca 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -796,7 +796,7 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): assert f1 in b1.constraints assert len(b1.constraints) == 2 - def test_override_keys(self, connection, metadata): + def test_override_keys(self, metadata, connection): """test that columns can be overridden with a 'key', and that ForeignKey targeting during reflection still works.""" @@ -1375,7 +1375,7 @@ class CreateDropTest(fixtures.TablesTest): run_create_tables = None @classmethod - def teardown_class(cls): + def teardown_test_class(cls): # TablesTest is used here without # run_create_tables, so add an explicit drop of whatever is in # metadata @@ -1658,7 +1658,6 @@ class SchemaTest(fixtures.TestBase): @testing.requires.schemas @testing.requires.cross_schema_fk_reflection @testing.requires.implicit_default_schema - @testing.provide_metadata def test_blank_schema_arg(self, connection, metadata): Table( @@ -1913,7 +1912,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): __backend__ = True @testing.requires.denormalized_names - def setup(self): + def setup_test(self): with testing.db.begin() as conn: conn.exec_driver_sql( """ @@ -1926,7 +1925,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): ) @testing.requires.denormalized_names - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: conn.exec_driver_sql("drop table weird_casing") diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index 79126fc5bb..47504b60a3 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -1,6 +1,5 @@ import sys -from sqlalchemy import create_engine from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func @@ -640,12 +639,12 @@ class AutoRollbackTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): global metadata metadata = MetaData() @classmethod - def teardown_class(cls): + def teardown_test_class(cls): metadata.drop_all(testing.db) def test_rollback_deadlock(self): @@ -871,11 +870,13 @@ class IsolationLevelTest(fixtures.TestBase): def test_per_engine(self): # new in 0.9 - eng = create_engine( + eng = testing_engine( testing.db.url, - execution_options={ - "isolation_level": self._non_default_isolation_level() - }, + options=dict( + execution_options={ + "isolation_level": self._non_default_isolation_level() + } + ), ) conn = eng.connect() eq_( @@ -884,7 +885,7 @@ class IsolationLevelTest(fixtures.TestBase): ) def test_per_option_engine(self): - eng = create_engine(testing.db.url).execution_options( + eng = testing_engine(testing.db.url).execution_options( isolation_level=self._non_default_isolation_level() ) @@ -895,14 +896,14 @@ class IsolationLevelTest(fixtures.TestBase): ) def test_isolation_level_accessors_connection_default(self): - eng = create_engine(testing.db.url) + eng = testing_engine(testing.db.url) with eng.connect() as conn: eq_(conn.default_isolation_level, self._default_isolation_level()) with eng.connect() as conn: eq_(conn.get_isolation_level(), self._default_isolation_level()) def test_isolation_level_accessors_connection_option_modified(self): - eng = create_engine(testing.db.url) + eng = testing_engine(testing.db.url) with eng.connect() as conn: c2 = conn.execution_options( isolation_level=self._non_default_isolation_level() diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index 7dae1411e5..59a44f8e2e 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -269,7 +269,7 @@ class AsyncEngineTest(EngineFixture): await trans.rollback(), @async_test - async def test_pool_exhausted(self, async_engine): + async def test_pool_exhausted_some_timeout(self, async_engine): engine = create_async_engine( testing.db.url, pool_size=1, @@ -277,7 +277,19 @@ class AsyncEngineTest(EngineFixture): pool_timeout=0.1, ) async with engine.connect(): - with expect_raises(asyncio.TimeoutError): + with expect_raises(exc.TimeoutError): + await engine.connect() + + @async_test + async def test_pool_exhausted_no_timeout(self, async_engine): + engine = create_async_engine( + testing.db.url, + pool_size=1, + max_overflow=0, + pool_timeout=0, + ) + async with engine.connect(): + with expect_raises(exc.TimeoutError): await engine.connect() @async_test diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index 2b80b753eb..e25e7cfc29 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -27,11 +27,11 @@ Base = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): + def setup_test(self): global Base Base = decl.declarative_base(testing.db) - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() Base.metadata.drop_all(testing.db) diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index d7fcbf9e8e..c327de7d4f 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -4,7 +4,6 @@ from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.declarative import DeferredReflection from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import decl_api as decl from sqlalchemy.orm import declared_attr from sqlalchemy.orm import exc as orm_exc @@ -14,6 +13,7 @@ from sqlalchemy.orm.decl_base import _DeferredMapperConfig from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -22,20 +22,19 @@ from sqlalchemy.testing.util import gc_collect class DeclarativeReflectionBase(fixtures.TablesTest): __requires__ = ("reflectable_autoincrement",) - def setup(self): + def setup_test(self): global Base, registry registry = decl.registry() Base = registry.generate_base() - def teardown(self): - super(DeclarativeReflectionBase, self).teardown() + def teardown_test(self): clear_mappers() class DeferredReflectBase(DeclarativeReflectionBase): - def teardown(self): - super(DeferredReflectBase, self).teardown() + def teardown_test(self): + super(DeferredReflectBase, self).teardown_test() _DeferredMapperConfig._configs.clear() @@ -101,22 +100,23 @@ class DeferredReflectionTest(DeferredReflectBase): u1 = User( name="u1", addresses=[Address(email="one"), Address(email="two")] ) - sess = create_session(testing.db) - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_( - sess.query(User).all(), - [ - User( - name="u1", - addresses=[Address(email="one"), Address(email="two")], - ) - ], - ) - a1 = sess.query(Address).filter(Address.email == "two").one() - eq_(a1, Address(email="two")) - eq_(a1.user, User(name="u1")) + with fixture_session() as sess: + sess.add(u1) + sess.commit() + + with fixture_session() as sess: + eq_( + sess.query(User).all(), + [ + User( + name="u1", + addresses=[Address(email="one"), Address(email="two")], + ) + ], + ) + a1 = sess.query(Address).filter(Address.email == "two").one() + eq_(a1, Address(email="two")) + eq_(a1.user, User(name="u1")) def test_exception_prepare_not_called(self): class User(DeferredReflection, fixtures.ComparableEntity, Base): @@ -191,15 +191,25 @@ class DeferredReflectionTest(DeferredReflectBase): return {"primary_key": cls.__table__.c.id} DeferredReflection.prepare(testing.db) - sess = Session(testing.db) - sess.add_all( - [User(name="G"), User(name="Q"), User(name="A"), User(name="C")] - ) - sess.commit() - eq_( - sess.query(User).order_by(User.name).all(), - [User(name="A"), User(name="C"), User(name="G"), User(name="Q")], - ) + with fixture_session() as sess: + sess.add_all( + [ + User(name="G"), + User(name="Q"), + User(name="A"), + User(name="C"), + ] + ) + sess.commit() + eq_( + sess.query(User).order_by(User.name).all(), + [ + User(name="A"), + User(name="C"), + User(name="G"), + User(name="Q"), + ], + ) @testing.requires.predictable_gc def test_cls_not_strong_ref(self): @@ -255,14 +265,14 @@ class DeferredSecondaryReflectionTest(DeferredReflectBase): u1 = User(name="u1", items=[Item(name="i1"), Item(name="i2")]) - sess = Session(testing.db) - sess.add(u1) - sess.commit() + with fixture_session() as sess: + sess.add(u1) + sess.commit() - eq_( - sess.query(User).all(), - [User(name="u1", items=[Item(name="i1"), Item(name="i2")])], - ) + eq_( + sess.query(User).all(), + [User(name="u1", items=[Item(name="i1"), Item(name="i2")])], + ) def test_string_resolution(self): class User(DeferredReflection, fixtures.ComparableEntity, Base): @@ -296,27 +306,26 @@ class DeferredInhReflectBase(DeferredReflectBase): Foo = Base.registry._class_registry["Foo"] Bar = Base.registry._class_registry["Bar"] - s = Session(testing.db) - - s.add_all( - [ - Bar(data="d1", bar_data="b1"), - Bar(data="d2", bar_data="b2"), - Bar(data="d3", bar_data="b3"), - Foo(data="d4"), - ] - ) - s.commit() - - eq_( - s.query(Foo).order_by(Foo.id).all(), - [ - Bar(data="d1", bar_data="b1"), - Bar(data="d2", bar_data="b2"), - Bar(data="d3", bar_data="b3"), - Foo(data="d4"), - ], - ) + with fixture_session() as s: + s.add_all( + [ + Bar(data="d1", bar_data="b1"), + Bar(data="d2", bar_data="b2"), + Bar(data="d3", bar_data="b3"), + Foo(data="d4"), + ] + ) + s.commit() + + eq_( + s.query(Foo).order_by(Foo.id).all(), + [ + Bar(data="d1", bar_data="b1"), + Bar(data="d2", bar_data="b2"), + Bar(data="d3", bar_data="b3"), + Foo(data="d4"), + ], + ) class DeferredSingleInhReflectionTest(DeferredInhReflectBase): diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index b1f5cc956f..31ae050c11 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -101,7 +101,7 @@ class AutoFlushTest(fixtures.TablesTest): Column("name", String(50)), ) - def teardown(self): + def teardown_test(self): clear_mappers() def _fixture(self, collection_class, is_dict=False): @@ -198,7 +198,7 @@ class AutoFlushTest(fixtures.TablesTest): class _CollectionOperations(fixtures.TestBase): - def setup(self): + def setup_test(self): collection_class = self.collection_class metadata = MetaData() @@ -260,7 +260,7 @@ class _CollectionOperations(fixtures.TestBase): self.session = fixture_session() self.Parent, self.Child = Parent, Child - def teardown(self): + def teardown_test(self): self.metadata.drop_all(testing.db) def roundtrip(self, obj): @@ -885,7 +885,7 @@ class CustomObjectTest(_CollectionOperations): class ProxyFactoryTest(ListTest): - def setup(self): + def setup_test(self): metadata = MetaData() parents_table = Table( @@ -1157,7 +1157,7 @@ class ScalarTest(fixtures.TestBase): class LazyLoadTest(fixtures.TestBase): - def setup(self): + def setup_test(self): metadata = MetaData() parents_table = Table( @@ -1197,7 +1197,7 @@ class LazyLoadTest(fixtures.TestBase): self.Parent, self.Child = Parent, Child self.table = parents_table - def teardown(self): + def teardown_test(self): self.metadata.drop_all(testing.db) def roundtrip(self, obj): @@ -2294,7 +2294,7 @@ class ComparatorTest(fixtures.MappedTest, AssertsCompiledSQL): class DictOfTupleUpdateTest(fixtures.TestBase): - def setup(self): + def setup_test(self): class B(object): def __init__(self, key, elem): self.key = key @@ -2434,7 +2434,7 @@ class CompositeAccessTest(fixtures.DeclarativeMappedTest): class AttributeAccessTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): clear_mappers() def test_resolve_aliased_class(self): diff --git a/test/ext/test_baked.py b/test/ext/test_baked.py index 71fabc629f..2d4e9848e5 100644 --- a/test/ext/test_baked.py +++ b/test/ext/test_baked.py @@ -27,7 +27,7 @@ class BakedTest(_fixtures.FixtureTest): run_inserts = "once" run_deletes = None - def setup(self): + def setup_test(self): self.bakery = baked.bakery() diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index 058c1dfd77..d011417d77 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -426,7 +426,7 @@ class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def teardown(self): + def teardown_test(self): for cls in (Select, BindParameter): deregister(cls) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index ad9bf0bc05..f3eceb0dca 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -32,7 +32,7 @@ def modifies_instrumentation_finders(fn, *args, **kw): class _ExtBase(object): @classmethod - def teardown_class(cls): + def teardown_test_class(cls): instrumentation._reinstall_default_lookups() @@ -89,7 +89,7 @@ MyBaseClass, MyClass = None, None class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): @classmethod - def setup_class(cls): + def setup_test_class(cls): global MyBaseClass, MyClass class MyBaseClass(object): @@ -143,7 +143,7 @@ class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest): else: del self._goofy_dict[key] - def teardown(self): + def teardown_test(self): clear_mappers() def test_instance_dict(self): diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index 038bdd83e1..bb06d9648b 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -19,7 +19,6 @@ from sqlalchemy import update from sqlalchemy import util from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.orm import clear_mappers -from sqlalchemy.orm import create_session from sqlalchemy.orm import deferred from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship @@ -42,7 +41,7 @@ class ShardTest(object): schema = None - def setUp(self): + def setup_test(self): global db1, db2, db3, db4, weather_locations, weather_reports db1, db2, db3, db4 = self._dbs = self._init_dbs() @@ -88,7 +87,7 @@ class ShardTest(object): @classmethod def setup_session(cls): - global create_session + global sharded_session shard_lookup = { "North America": "north_america", "Asia": "asia", @@ -128,10 +127,10 @@ class ShardTest(object): else: return ids - create_session = sessionmaker( + sharded_session = sessionmaker( class_=ShardedSession, autoflush=True, autocommit=False ) - create_session.configure( + sharded_session.configure( shards={ "north_america": db1, "asia": db2, @@ -180,7 +179,7 @@ class ShardTest(object): tokyo.reports.append(Report(80.0, id_=1)) newyork.reports.append(Report(75, id_=1)) quito.reports.append(Report(85)) - sess = create_session(future=True) + sess = sharded_session(future=True) for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: sess.add(c) sess.flush() @@ -671,11 +670,10 @@ class DistinctEngineShardTest(ShardTest, fixtures.TestBase): self.dbs = [db1, db2, db3, db4] return self.dbs - def teardown(self): + def teardown_test(self): clear_mappers() - for db in self.dbs: - db.connect().invalidate() + testing_reaper.checkin_all() for i in range(1, 5): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) @@ -702,10 +700,10 @@ class AttachedFileShardTest(ShardTest, fixtures.TestBase): self.engine = e return db1, db2, db3, db4 - def teardown(self): + def teardown_test(self): clear_mappers() - self.engine.connect().invalidate() + testing_reaper.checkin_all() for i in range(1, 5): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) @@ -778,10 +776,13 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): self.postgresql_engine = e2 return db1, db2, db3, db4 - def teardown(self): + def teardown_test(self): clear_mappers() - self.sqlite_engine.connect().invalidate() + # the tests in this suite don't cleanly close out the Session + # at the moment so use the reaper to close all connections + testing_reaper.checkin_all() + for i in [1, 3]: os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) @@ -789,6 +790,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): self.tables_test_metadata.drop_all(conn) for i in [2, 4]: conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,)) + self.postgresql_engine.dispose() class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest): @@ -904,11 +906,11 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): return self.dbs - def teardown(self): + def teardown_test(self): for db in self.dbs: db.connect().invalidate() - testing_reaper.close_all() + testing_reaper.checkin_all() for i in range(1, 3): os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 048a8b52d1..3bab7db934 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -697,7 +697,7 @@ class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy import literal symbols = ("usd", "gbp", "cad", "eur", "aud") diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index eba2ac0cbb..21244de73d 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -90,11 +90,10 @@ class _MutableDictTestFixture(object): def _type_fixture(cls): return MutableDict - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_MutableDictTestFixture, self).teardown() class _MutableDictTestBase(_MutableDictTestFixture): @@ -312,11 +311,10 @@ class _MutableListTestFixture(object): def _type_fixture(cls): return MutableList - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_MutableListTestFixture, self).teardown() class _MutableListTestBase(_MutableListTestFixture): @@ -619,11 +617,10 @@ class _MutableSetTestFixture(object): def _type_fixture(cls): return MutableSet - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_MutableSetTestFixture, self).teardown() class _MutableSetTestBase(_MutableSetTestFixture): @@ -1234,17 +1231,15 @@ class _CompositeTestBase(object): Column("unrelated_data", String(50)), ) - def setup(self): + def setup_test(self): from sqlalchemy.ext import mutable mutable._setup_composite_listener() - super(_CompositeTestBase, self).setup() - def teardown(self): + def teardown_test(self): # clear out mapper events Mapper.dispatch._clear() ClassManager.dispatch._clear() - super(_CompositeTestBase, self).teardown() @classmethod def _type_fixture(cls): diff --git a/test/ext/test_orderinglist.py b/test/ext/test_orderinglist.py index f23d6cb576..280fad6cf0 100644 --- a/test/ext/test_orderinglist.py +++ b/test/ext/test_orderinglist.py @@ -8,7 +8,7 @@ from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing.fixtures import create_session +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import picklers @@ -60,7 +60,7 @@ def alpha_ordering(index, collection): class OrderingListTest(fixtures.TestBase): - def setup(self): + def setup_test(self): global metadata, slides_table, bullets_table, Slide, Bullet slides_table, bullets_table = None, None Slide, Bullet = None, None @@ -122,7 +122,7 @@ class OrderingListTest(fixtures.TestBase): metadata.create_all(testing.db) - def teardown(self): + def teardown_test(self): metadata.drop_all(testing.db) def test_append_no_reorder(self): @@ -167,7 +167,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[2].position == 3) self.assert_(s1.bullets[3].position == 4) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -232,7 +232,7 @@ class OrderingListTest(fixtures.TestBase): s1.bullets._reorder() self.assert_(s1.bullets[4].position == 5) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -289,7 +289,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(len(s1.bullets) == 6) self.assert_(s1.bullets[5].position == 5) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -338,7 +338,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(s1.bullets[li].position == li) self.assert_(s1.bullets[li] == b[bi]) - session = create_session() + session = fixture_session() session.add(s1) session.flush() @@ -365,7 +365,7 @@ class OrderingListTest(fixtures.TestBase): self.assert_(len(s1.bullets) == 3) self.assert_(s1.bullets[2].position == 2) - session = create_session() + session = fixture_session() session.add(s1) session.flush() diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 4c005d336c..4d9162105b 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -61,11 +61,11 @@ class DeclarativeTestBase( ): __dialect__ = "default" - def setup(self): + def setup_test(self): global Base Base = declarative_base(testing.db) - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() Base.metadata.drop_all(testing.db) diff --git a/test/orm/declarative/test_concurrency.py b/test/orm/declarative/test_concurrency.py index 5f12d82723..ecddc2e5fa 100644 --- a/test/orm/declarative/test_concurrency.py +++ b/test/orm/declarative/test_concurrency.py @@ -17,7 +17,7 @@ from sqlalchemy.testing.fixtures import fixture_session class ConcurrentUseDeclMappingTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): clear_mappers() @classmethod diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index cc29cab7de..e09b1570e2 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -27,11 +27,11 @@ Base = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): + def setup_test(self): global Base Base = decl.declarative_base(testing.db) - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() Base.metadata.drop_all(testing.db) diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 631527daf9..ad4832c357 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -38,13 +38,13 @@ mapper_registry = None class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): - def setup(self): + def setup_test(self): global Base, mapper_registry mapper_registry = registry(metadata=MetaData()) Base = mapper_registry.generate_base() - def teardown(self): + def teardown_test(self): close_all_sessions() clear_mappers() with testing.db.begin() as conn: diff --git a/test/orm/declarative/test_reflection.py b/test/orm/declarative/test_reflection.py index 241528c44e..e7b2a70588 100644 --- a/test/orm/declarative/test_reflection.py +++ b/test/orm/declarative/test_reflection.py @@ -17,14 +17,13 @@ from sqlalchemy.testing.schema import Table class DeclarativeReflectionBase(fixtures.TablesTest): __requires__ = ("reflectable_autoincrement",) - def setup(self): + def setup_test(self): global Base, registry registry = decl.registry(metadata=MetaData()) Base = registry.generate_base() - def teardown(self): - super(DeclarativeReflectionBase, self).teardown() + def teardown_test(self): clear_mappers() diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index bdcdedc44e..da07b4941b 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -31,7 +31,6 @@ from sqlalchemy.orm import synonym from sqlalchemy.orm.util import instance_str from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message -from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -1889,7 +1888,6 @@ class VersioningTest(fixtures.MappedTest): @testing.emits_warning(r".*updated rowcount") @testing.requires.sane_rowcount_w_returning - @engines.close_open_connections def test_save_update(self): subtable, base, stuff = ( self.tables.subtable, @@ -2927,7 +2925,7 @@ class NoPKOnSubTableWarningTest(fixtures.TestBase): ) return parent, child - def tearDown(self): + def teardown_test(self): clear_mappers() def test_warning_on_sub(self): @@ -3417,27 +3415,26 @@ class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest): @classmethod def insert_data(cls, connection): Parent, A, B = cls.classes("Parent", "A", "B") - s = fixture_session() - - p1 = Parent(id=1) - p2 = Parent(id=2) - s.add_all([p1, p2]) - s.flush() + with Session(connection) as s: + p1 = Parent(id=1) + p2 = Parent(id=2) + s.add_all([p1, p2]) + s.flush() - s.add_all( - [ - A(id=1, parent_id=1), - B(id=2, parent_id=1), - A(id=3, parent_id=1), - B(id=4, parent_id=1), - ] - ) - s.flush() + s.add_all( + [ + A(id=1, parent_id=1), + B(id=2, parent_id=1), + A(id=3, parent_id=1), + B(id=4, parent_id=1), + ] + ) + s.flush() - s.query(A).filter(A.id.in_([3, 4])).update( - {A.type: None}, synchronize_session=False - ) - s.commit() + s.query(A).filter(A.id.in_([3, 4])).update( + {A.type: None}, synchronize_session=False + ) + s.commit() def test_pk_is_null(self): Parent, A = self.classes("Parent", "A") @@ -3527,10 +3524,12 @@ class UnexpectedPolymorphicIdentityTest(fixtures.DeclarativeMappedTest): ASingleSubA, ASingleSubB, AJoinedSubA, AJoinedSubB = cls.classes( "ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB" ) - s = fixture_session() + with Session(connection) as s: - s.add_all([ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]) - s.commit() + s.add_all( + [ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()] + ) + s.commit() def test_single_invalid_ident(self): ASingle, ASingleSubA = self.classes("ASingle", "ASingleSubA") diff --git a/test/orm/test_attributes.py b/test/orm/test_attributes.py index 8820aa6a45..0a0a5d12b7 100644 --- a/test/orm/test_attributes.py +++ b/test/orm/test_attributes.py @@ -209,7 +209,7 @@ class AttributeImplAPITest(fixtures.MappedTest): class AttributesTest(fixtures.ORMTest): - def setup(self): + def setup_test(self): global MyTest, MyTest2 class MyTest(object): @@ -218,7 +218,7 @@ class AttributesTest(fixtures.ORMTest): class MyTest2(object): pass - def teardown(self): + def teardown_test(self): global MyTest, MyTest2 MyTest, MyTest2 = None, None @@ -3690,7 +3690,7 @@ class EventPropagateTest(fixtures.TestBase): class CollectionInitTest(fixtures.TestBase): - def setUp(self): + def setup_test(self): class A(object): pass @@ -3749,7 +3749,7 @@ class CollectionInitTest(fixtures.TestBase): class TestUnlink(fixtures.TestBase): - def setUp(self): + def setup_test(self): class A(object): pass diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 2f54f7fff0..014fa152e9 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -428,39 +428,46 @@ class BindIntegrationTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - c = testing.db.connect() - - sess = Session(bind=c, autocommit=False) - u = User(name="u1") - sess.add(u) - sess.flush() - sess.close() - assert not c.in_transaction() - assert c.exec_driver_sql("select count(1) from users").scalar() == 0 - - sess = Session(bind=c, autocommit=False) - u = User(name="u2") - sess.add(u) - sess.flush() - sess.commit() - assert not c.in_transaction() - assert c.exec_driver_sql("select count(1) from users").scalar() == 1 - - with c.begin(): - c.exec_driver_sql("delete from users") - assert c.exec_driver_sql("select count(1) from users").scalar() == 0 - - c = testing.db.connect() - - trans = c.begin() - sess = Session(bind=c, autocommit=True) - u = User(name="u3") - sess.add(u) - sess.flush() - assert c.in_transaction() - trans.commit() - assert not c.in_transaction() - assert c.exec_driver_sql("select count(1) from users").scalar() == 1 + with testing.db.connect() as c: + + sess = Session(bind=c, autocommit=False) + u = User(name="u1") + sess.add(u) + sess.flush() + sess.close() + assert not c.in_transaction() + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 0 + ) + + sess = Session(bind=c, autocommit=False) + u = User(name="u2") + sess.add(u) + sess.flush() + sess.commit() + assert not c.in_transaction() + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 1 + ) + + with c.begin(): + c.exec_driver_sql("delete from users") + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 0 + ) + + with testing.db.connect() as c: + trans = c.begin() + sess = Session(bind=c, autocommit=True) + u = User(name="u3") + sess.add(u) + sess.flush() + assert c.in_transaction() + trans.commit() + assert not c.in_transaction() + assert ( + c.exec_driver_sql("select count(1) from users").scalar() == 1 + ) class SessionBindTest(fixtures.MappedTest): @@ -506,6 +513,7 @@ class SessionBindTest(fixtures.MappedTest): finally: if hasattr(bind, "close"): bind.close() + sess.close() def test_session_unbound(self): Foo = self.classes.Foo diff --git a/test/orm/test_collection.py b/test/orm/test_collection.py index 3d09bd4460..2a0aafbbcc 100644 --- a/test/orm/test_collection.py +++ b/test/orm/test_collection.py @@ -92,13 +92,12 @@ class CollectionsTest(OrderedDictFixture, fixtures.ORMTest): return str((id(self), self.a, self.b, self.c)) @classmethod - def setup_class(cls): + def setup_test_class(cls): instrumentation.register_class(cls.Entity) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): instrumentation.unregister_class(cls.Entity) - super(CollectionsTest, cls).teardown_class() _entity_id = 1 diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py index df652daf45..20d8ecc2db 100644 --- a/test/orm/test_compile.py +++ b/test/orm/test_compile.py @@ -19,7 +19,7 @@ from sqlalchemy.testing import fixtures class CompileTest(fixtures.ORMTest): """test various mapper compilation scenarios""" - def teardown(self): + def teardown_test(self): clear_mappers() def test_with_polymorphic(self): diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index e1ef67fed6..ed11b89c9b 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -1743,8 +1743,7 @@ class PostUpdateOnUpdateTest(fixtures.DeclarativeMappedTest): id = Column(Integer, primary_key=True) a_id = Column(ForeignKey("a.id", name="a_fk")) - def setup(self): - super(PostUpdateOnUpdateTest, self).setup() + def setup_test(self): PostUpdateOnUpdateTest.counter = count() PostUpdateOnUpdateTest.db_counter = count() diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 6d946cfe6e..15063ebe92 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -2199,6 +2199,7 @@ class SessionTest(fixtures.RemovesEvents, _LocalFixture): class AutocommitClosesOnFailTest(fixtures.MappedTest): __requires__ = ("deferrable_fks",) + __only_on__ = ("postgresql+psycopg2",) # needs #5824 for asyncpg @classmethod def define_tables(cls, metadata): @@ -4498,44 +4499,49 @@ class JoinTest(QueryTest, AssertsCompiledSQL): warnings += (join_aliased_dep,) # load a user who has an order that contains item id 3 and address # id 1 (order 3, owned by jack) - with testing.expect_deprecated_20(*warnings): - result = ( - fixture_session() - .query(User) - .join("orders", "items", aliased=aliased_) - .filter_by(id=3) - .reset_joinpoint() - .join("orders", "address", aliased=aliased_) - .filter_by(id=1) - .all() - ) - assert [User(id=7, name="jack")] == result - with testing.expect_deprecated_20(*warnings): - result = ( - fixture_session() - .query(User) - .join("orders", "items", aliased=aliased_, isouter=True) - .filter_by(id=3) - .reset_joinpoint() - .join("orders", "address", aliased=aliased_, isouter=True) - .filter_by(id=1) - .all() - ) - assert [User(id=7, name="jack")] == result - - with testing.expect_deprecated_20(*warnings): - result = ( - fixture_session() - .query(User) - .outerjoin("orders", "items", aliased=aliased_) - .filter_by(id=3) - .reset_joinpoint() - .outerjoin("orders", "address", aliased=aliased_) - .filter_by(id=1) - .all() - ) - assert [User(id=7, name="jack")] == result + with fixture_session() as sess: + with testing.expect_deprecated_20(*warnings): + result = ( + sess.query(User) + .join("orders", "items", aliased=aliased_) + .filter_by(id=3) + .reset_joinpoint() + .join("orders", "address", aliased=aliased_) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result + + with fixture_session() as sess: + with testing.expect_deprecated_20(*warnings): + result = ( + sess.query(User) + .join( + "orders", "items", aliased=aliased_, isouter=True + ) + .filter_by(id=3) + .reset_joinpoint() + .join( + "orders", "address", aliased=aliased_, isouter=True + ) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result + + with fixture_session() as sess: + with testing.expect_deprecated_20(*warnings): + result = ( + sess.query(User) + .outerjoin("orders", "items", aliased=aliased_) + .filter_by(id=3) + .reset_joinpoint() + .outerjoin("orders", "address", aliased=aliased_) + .filter_by(id=1) + .all() + ) + assert [User(id=7, name="jack")] == result class AliasFromCorrectLeftTest( diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 4498fc1ff9..7eedb37c92 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -559,15 +559,15 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): 5, ), ]: - sess = fixture_session() + with fixture_session() as sess: - def go(): - eq_( - sess.query(User).options(*opt).order_by(User.id).all(), - self.static.user_item_keyword_result, - ) + def go(): + eq_( + sess.query(User).options(*opt).order_by(User.id).all(), + self.static.user_item_keyword_result, + ) - self.assert_sql_count(testing.db, go, count) + self.assert_sql_count(testing.db, go, count) def test_disable_dynamic(self): """test no joined option on a dynamic.""" diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 1c918a88cd..e85c23d6f6 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -45,13 +45,14 @@ from test.orm import _fixtures class _RemoveListeners(object): - def teardown(self): + @testing.fixture(autouse=True) + def _remove_listeners(self): + yield events.MapperEvents._clear() events.InstanceEvents._clear() events.SessionEvents._clear() events.InstrumentationEvents._clear() events.QueryEvents._clear() - super(_RemoveListeners, self).teardown() class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): @@ -1174,7 +1175,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): argnames="target, event_name, fn", )(fn) - def teardown(self): + def teardown_test(self): A = self.classes.A A._sa_class_manager.dispatch._clear() diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index cc95964664..f622bff025 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -2395,16 +2395,19 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): ] adalias = addresses.alias() - q = ( - fixture_session() - .query(User) - .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name)) - .outerjoin(adalias, "addresses") - .group_by(users) - .order_by(users.c.id) - ) - assert q.all() == expected + with fixture_session() as sess: + q = ( + sess.query(User) + .add_columns( + func.count(adalias.c.id), ("Name:" + users.c.name) + ) + .outerjoin(adalias, "addresses") + .group_by(users) + .order_by(users.c.id) + ) + + eq_(q.all(), expected) # test with a straight statement s = ( @@ -2417,52 +2420,57 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .group_by(*[c for c in users.c]) .order_by(users.c.id) ) - q = fixture_session().query(User) - result = ( - q.add_columns(s.selected_columns.count, s.selected_columns.concat) - .from_statement(s) - .all() - ) - assert result == expected - - sess.expunge_all() - # test with select_entity_from() - q = ( - fixture_session() - .query(User) - .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name)) - .select_entity_from(users.outerjoin(addresses)) - .group_by(users) - .order_by(users.c.id) - ) + with fixture_session() as sess: + q = sess.query(User) + result = ( + q.add_columns( + s.selected_columns.count, s.selected_columns.concat + ) + .from_statement(s) + .all() + ) + eq_(result, expected) - assert q.all() == expected - sess.expunge_all() + with fixture_session() as sess: + # test with select_entity_from() + q = ( + fixture_session() + .query(User) + .add_columns( + func.count(addresses.c.id), ("Name:" + users.c.name) + ) + .select_entity_from(users.outerjoin(addresses)) + .group_by(users) + .order_by(users.c.id) + ) - q = ( - fixture_session() - .query(User) - .add_columns(func.count(addresses.c.id), ("Name:" + users.c.name)) - .outerjoin("addresses") - .group_by(users) - .order_by(users.c.id) - ) + eq_(q.all(), expected) - assert q.all() == expected - sess.expunge_all() + with fixture_session() as sess: + q = ( + sess.query(User) + .add_columns( + func.count(addresses.c.id), ("Name:" + users.c.name) + ) + .outerjoin("addresses") + .group_by(users) + .order_by(users.c.id) + ) + eq_(q.all(), expected) - q = ( - fixture_session() - .query(User) - .add_columns(func.count(adalias.c.id), ("Name:" + users.c.name)) - .outerjoin(adalias, "addresses") - .group_by(users) - .order_by(users.c.id) - ) + with fixture_session() as sess: + q = ( + sess.query(User) + .add_columns( + func.count(adalias.c.id), ("Name:" + users.c.name) + ) + .outerjoin(adalias, "addresses") + .group_by(users) + .order_by(users.c.id) + ) - assert q.all() == expected - sess.expunge_all() + eq_(q.all(), expected) def test_expression_selectable_matches_mzero(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 3061de309b..43cf81e6d4 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -717,24 +717,24 @@ class LazyTest(_fixtures.FixtureTest): ), ) - sess = fixture_session() + with fixture_session() as sess: - # load address - a1 = ( - sess.query(Address) - .filter_by(email_address="ed@wood.com") - .one() - ) + # load address + a1 = ( + sess.query(Address) + .filter_by(email_address="ed@wood.com") + .one() + ) - # load user that is attached to the address - u1 = sess.query(User).get(8) + # load user that is attached to the address + u1 = sess.query(User).get(8) - def go(): - # lazy load of a1.user should get it from the session - assert a1.user is u1 + def go(): + # lazy load of a1.user should get it from the session + assert a1.user is u1 - self.assert_sql_count(testing.db, go, 0) - sa.orm.clear_mappers() + self.assert_sql_count(testing.db, go, 0) + sa.orm.clear_mappers() def test_uses_get_compatible_types(self): """test the use_get optimization with compatible @@ -789,24 +789,23 @@ class LazyTest(_fixtures.FixtureTest): properties=dict(user=relationship(mapper(User, users))), ) - sess = fixture_session() - - # load address - a1 = ( - sess.query(Address) - .filter_by(email_address="ed@wood.com") - .one() - ) + with fixture_session() as sess: + # load address + a1 = ( + sess.query(Address) + .filter_by(email_address="ed@wood.com") + .one() + ) - # load user that is attached to the address - u1 = sess.query(User).get(8) + # load user that is attached to the address + u1 = sess.query(User).get(8) - def go(): - # lazy load of a1.user should get it from the session - assert a1.user is u1 + def go(): + # lazy load of a1.user should get it from the session + assert a1.user is u1 - self.assert_sql_count(testing.db, go, 0) - sa.orm.clear_mappers() + self.assert_sql_count(testing.db, go, 0) + sa.orm.clear_mappers() def test_many_to_one(self): users, Address, addresses, User = ( diff --git a/test/orm/test_load_on_fks.py b/test/orm/test_load_on_fks.py index 0e8ac97e3e..42b5b3e459 100644 --- a/test/orm/test_load_on_fks.py +++ b/test/orm/test_load_on_fks.py @@ -9,14 +9,12 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import instance_state from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column -engine = testing.db - - class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase): - def setUp(self): + def setup_test(self): global Parent, Child, Base Base = declarative_base() @@ -36,27 +34,27 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase): ) parent_id = Column(Integer, ForeignKey("parent.id")) - Base.metadata.create_all(engine) + Base.metadata.create_all(testing.db) - def tearDown(self): - Base.metadata.drop_all(engine) + def teardown_test(self): + Base.metadata.drop_all(testing.db) def test_annoying_autoflush_one(self): - sess = Session(engine) + sess = fixture_session() p1 = Parent() sess.add(p1) p1.children = [] def test_annoying_autoflush_two(self): - sess = Session(engine) + sess = fixture_session() p1 = Parent() sess.add(p1) assert p1.children == [] def test_dont_load_if_no_keys(self): - sess = Session(engine) + sess = fixture_session() p1 = Parent() sess.add(p1) @@ -68,7 +66,9 @@ class FlushOnPendingTest(AssertsExecutionResults, fixtures.TestBase): class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): - def setUp(self): + __leave_connections_for_teardown__ = True + + def setup_test(self): global Parent, Child, Base Base = declarative_base() @@ -91,10 +91,10 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): parent = relationship(Parent, backref=backref("children")) - Base.metadata.create_all(engine) + Base.metadata.create_all(testing.db) global sess, p1, p2, c1, c2 - sess = Session(bind=engine) + sess = Session(bind=testing.db) p1 = Parent() p2 = Parent() @@ -105,9 +105,9 @@ class LoadOnFKsTest(AssertsExecutionResults, fixtures.TestBase): sess.commit() - def tearDown(self): + def teardown_test(self): sess.rollback() - Base.metadata.drop_all(engine) + Base.metadata.drop_all(testing.db) def test_load_on_pending_allows_backref_event(self): Child.parent.property.load_on_pending = True diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 013eb21e11..d182fd2c17 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2560,7 +2560,7 @@ class MagicNamesTest(fixtures.MappedTest): class DocumentTest(fixtures.TestBase): - def setup(self): + def setup_test(self): self.mapper = registry().map_imperatively @@ -2624,14 +2624,14 @@ class DocumentTest(fixtures.TestBase): class ORMLoggingTest(_fixtures.FixtureTest): - def setup(self): + def setup_test(self): self.buf = logging.handlers.BufferingHandler(100) for log in [logging.getLogger("sqlalchemy.orm")]: log.addHandler(self.buf) self.mapper = registry().map_imperatively - def teardown(self): + def teardown_test(self): for log in [logging.getLogger("sqlalchemy.orm")]: log.removeHandler(self.buf) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index b22b318e9b..6f47c1238a 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -1523,9 +1523,7 @@ class PickleTest(PathTest, QueryTest): class LocalOptsTest(PathTest, QueryTest): @classmethod - def setup_class(cls): - super(LocalOptsTest, cls).setup_class() - + def setup_test_class(cls): @strategy_options.loader_option() def some_col_opt_only(loadopt, key, opts): return loadopt.set_column_strategy( diff --git a/test/orm/test_query.py b/test/orm/test_query.py index fd8e849fb5..7546ba1626 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -6000,12 +6000,19 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): [User.orders_syn, Order.items_syn], [User.orders_syn_2, Order.items_syn], ): - q = fixture_session().query(User) - for path in j: - q = q.join(path) - q = q.filter_by(id=3) - result = q.all() - assert [User(id=7, name="jack"), User(id=9, name="fred")] == result + with fixture_session() as sess: + q = sess.query(User) + for path in j: + q = q.join(path) + q = q.filter_by(id=3) + result = q.all() + eq_( + result, + [ + User(id=7, name="jack"), + User(id=9, name="fred"), + ], + ) def test_with_parent(self): Order, User = self.classes.Order, self.classes.User @@ -6018,17 +6025,17 @@ class SynonymTest(QueryTest, AssertsCompiledSQL): ("name_syn", "orders_syn"), ("name_syn", "orders_syn_2"), ): - sess = fixture_session() - q = sess.query(User) + with fixture_session() as sess: + q = sess.query(User) - u1 = q.filter_by(**{nameprop: "jack"}).one() + u1 = q.filter_by(**{nameprop: "jack"}).one() - o = sess.query(Order).with_parent(u1, property=orderprop).all() - assert [ - Order(description="order 1"), - Order(description="order 3"), - Order(description="order 5"), - ] == o + o = sess.query(Order).with_parent(u1, property=orderprop).all() + assert [ + Order(description="order 1"), + Order(description="order 3"), + Order(description="order 5"), + ] == o def test_froms_aliased_col(self): Address, User = self.classes.Address, self.classes.User diff --git a/test/orm/test_rel_fn.py b/test/orm/test_rel_fn.py index 12c084b2d2..ef1bf2e603 100644 --- a/test/orm/test_rel_fn.py +++ b/test/orm/test_rel_fn.py @@ -26,7 +26,7 @@ from sqlalchemy.testing import mock class _JoinFixtures(object): @classmethod - def setup_class(cls): + def setup_test_class(cls): m = MetaData() cls.left = Table( "lft", diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 5979f08ae6..8d73cd40e1 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -637,7 +637,7 @@ class OverlappingFksSiblingTest(fixtures.TestBase): """ - def teardown(self): + def teardown_test(self): clear_mappers() def _fixture_one( @@ -2474,7 +2474,7 @@ class JoinConditionErrorTest(fixtures.TestBase): assert_raises(sa.exc.ArgumentError, configure_mappers) - def teardown(self): + def teardown_test(self): clear_mappers() @@ -4354,7 +4354,7 @@ class AmbiguousFKResolutionTest(_RelationshipErrors, fixtures.MappedTest): class SecondaryArgTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): clear_mappers() @testing.combinations((True,), (False,)) diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index 5535fe5d68..4895c7d3a0 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -713,35 +713,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def _do_query_tests(self, opts, count): Order, User = self.classes.Order, self.classes.User - sess = fixture_session() + with fixture_session() as sess: - def go(): - eq_( - sess.query(User).options(*opts).order_by(User.id).all(), - self.static.user_item_keyword_result, - ) + def go(): + eq_( + sess.query(User).options(*opts).order_by(User.id).all(), + self.static.user_item_keyword_result, + ) - self.assert_sql_count(testing.db, go, count) + self.assert_sql_count(testing.db, go, count) - eq_( - sess.query(User) - .options(*opts) - .filter(User.name == "fred") - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[2:3], - ) + eq_( + sess.query(User) + .options(*opts) + .filter(User.name == "fred") + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[2:3], + ) - sess = fixture_session() - eq_( - sess.query(User) - .options(*opts) - .join(User.orders) - .filter(Order.id == 3) - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[0:1], - ) + with fixture_session() as sess: + eq_( + sess.query(User) + .options(*opts) + .join(User.orders) + .filter(Order.id == 3) + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[0:1], + ) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 20c4752b82..3d4566af3e 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1273,7 +1273,7 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): run_inserts = None - def setup(self): + def setup_test(self): mapper(self.classes.User, self.tables.users) def _assert_modified(self, u1): @@ -1288,11 +1288,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): def _assert_no_cycle(self, u1): assert sa.orm.attributes.instance_state(u1)._strong_obj is None - def _persistent_fixture(self): + def _persistent_fixture(self, gc_collect=False): User = self.classes.User u1 = User() u1.name = "ed" - sess = fixture_session() + if gc_collect: + sess = Session(testing.db) + else: + sess = fixture_session() sess.add(u1) sess.flush() return sess, u1 @@ -1389,14 +1392,14 @@ class NoCyclesOnTransientDetachedTest(_fixtures.FixtureTest): @testing.requires.predictable_gc def test_move_gc_session_persistent_dirty(self): - sess, u1 = self._persistent_fixture() + sess, u1 = self._persistent_fixture(gc_collect=True) u1.name = "edchanged" self._assert_cycle(u1) self._assert_modified(u1) del sess gc_collect() self._assert_cycle(u1) - s2 = fixture_session() + s2 = Session(testing.db) s2.add(u1) self._assert_cycle(u1) self._assert_modified(u1) @@ -1565,7 +1568,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): mapper(User, users) - sess = fixture_session() + sess = Session(testing.db) u1 = User(name="u1") sess.add(u1) @@ -1573,7 +1576,7 @@ class WeakIdentityMapTest(_fixtures.FixtureTest): # can't add u1 to Session, # already belongs to u2 - s2 = fixture_session() + s2 = Session(testing.db) assert_raises_message( sa.exc.InvalidRequestError, r".*is already attached to session", @@ -1725,11 +1728,10 @@ class DisposedStates(fixtures.MappedTest): mapper(T, cls.tables.t1) - def teardown(self): + def teardown_test(self): from sqlalchemy.orm.session import _sessions _sessions.clear() - super(DisposedStates, self).teardown() def _set_imap_in_disposal(self, sess, *objs): """remove selected objects from the given session, as though diff --git a/test/orm/test_subquery_relations.py b/test/orm/test_subquery_relations.py index fe20442a30..150cee2225 100644 --- a/test/orm/test_subquery_relations.py +++ b/test/orm/test_subquery_relations.py @@ -734,35 +734,35 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): def _do_query_tests(self, opts, count): Order, User = self.classes.Order, self.classes.User - sess = fixture_session() + with fixture_session() as sess: - def go(): - eq_( - sess.query(User).options(*opts).order_by(User.id).all(), - self.static.user_item_keyword_result, - ) + def go(): + eq_( + sess.query(User).options(*opts).order_by(User.id).all(), + self.static.user_item_keyword_result, + ) - self.assert_sql_count(testing.db, go, count) + self.assert_sql_count(testing.db, go, count) - eq_( - sess.query(User) - .options(*opts) - .filter(User.name == "fred") - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[2:3], - ) + eq_( + sess.query(User) + .options(*opts) + .filter(User.name == "fred") + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[2:3], + ) - sess = fixture_session() - eq_( - sess.query(User) - .options(*opts) - .join(User.orders) - .filter(Order.id == 3) - .order_by(User.id) - .all(), - self.static.user_item_keyword_result[0:1], - ) + with fixture_session() as sess: + eq_( + sess.query(User) + .options(*opts) + .join(User.orders) + .filter(Order.id == 3) + .order_by(User.id) + .all(), + self.static.user_item_keyword_result[0:1], + ) def test_cyclical(self): """A circular eager relationship breaks the cycle with a lazy loader""" diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index 550cf6535b..7f77b01c78 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -66,17 +66,18 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - conn = testing.db.connect() - trans = conn.begin() - sess = Session(bind=conn, autocommit=False, autoflush=True) - sess.begin(subtransactions=True) - u = User(name="ed") - sess.add(u) - sess.flush() - sess.commit() # commit does nothing - trans.rollback() # rolls back - assert len(sess.query(User).all()) == 0 - sess.close() + + with testing.db.connect() as conn: + trans = conn.begin() + sess = Session(bind=conn, autocommit=False, autoflush=True) + sess.begin(subtransactions=True) + u = User(name="ed") + sess.add(u) + sess.flush() + sess.commit() # commit does nothing + trans.rollback() # rolls back + assert len(sess.query(User).all()) == 0 + sess.close() @engines.close_open_connections def test_subtransaction_on_external_no_begin(self): @@ -260,34 +261,33 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): users = self.tables.users engine = Engine._future_facade(testing.db) - session = Session(engine, autocommit=False) - - session.begin() - session.connection().execute(users.insert().values(name="user1")) - session.begin(subtransactions=True) - session.begin_nested() - session.connection().execute(users.insert().values(name="user2")) - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 2 - ) - session.rollback() - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 1 - ) - session.connection().execute(users.insert().values(name="user3")) - session.commit() - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 2 - ) + with Session(engine, autocommit=False) as session: + session.begin() + session.connection().execute(users.insert().values(name="user1")) + session.begin(subtransactions=True) + session.begin_nested() + session.connection().execute(users.insert().values(name="user2")) + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 2 + ) + session.rollback() + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 1 + ) + session.connection().execute(users.insert().values(name="user3")) + session.commit() + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 2 + ) @testing.requires.savepoints def test_dirty_state_transferred_deep_nesting(self): @@ -295,27 +295,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - s = Session(testing.db) - u1 = User(name="u1") - s.add(u1) - s.commit() - - nt1 = s.begin_nested() - nt2 = s.begin_nested() - u1.name = "u2" - assert attributes.instance_state(u1) not in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty - s.flush() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty + with fixture_session() as s: + u1 = User(name="u1") + s.add(u1) + s.commit() + + nt1 = s.begin_nested() + nt2 = s.begin_nested() + u1.name = "u2" + assert attributes.instance_state(u1) not in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty + s.flush() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty - s.commit() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) in nt1._dirty + s.commit() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) in nt1._dirty - s.rollback() - assert attributes.instance_state(u1).expired - eq_(u1.name, "u1") + s.rollback() + assert attributes.instance_state(u1).expired + eq_(u1.name, "u1") @testing.requires.savepoints def test_dirty_state_transferred_deep_nesting_future(self): @@ -323,27 +323,27 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - s = Session(testing.db, future=True) - u1 = User(name="u1") - s.add(u1) - s.commit() - - nt1 = s.begin_nested() - nt2 = s.begin_nested() - u1.name = "u2" - assert attributes.instance_state(u1) not in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty - s.flush() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) not in nt1._dirty + with fixture_session(future=True) as s: + u1 = User(name="u1") + s.add(u1) + s.commit() + + nt1 = s.begin_nested() + nt2 = s.begin_nested() + u1.name = "u2" + assert attributes.instance_state(u1) not in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty + s.flush() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) not in nt1._dirty - nt2.commit() - assert attributes.instance_state(u1) in nt2._dirty - assert attributes.instance_state(u1) in nt1._dirty + nt2.commit() + assert attributes.instance_state(u1) in nt2._dirty + assert attributes.instance_state(u1) in nt1._dirty - nt1.rollback() - assert attributes.instance_state(u1).expired - eq_(u1.name, "u1") + nt1.rollback() + assert attributes.instance_state(u1).expired + eq_(u1.name, "u1") @testing.requires.independent_connections def test_transactions_isolated(self): @@ -1049,23 +1049,25 @@ class SessionTransactionTest(fixtures.RemovesEvents, FixtureTest): mapper(User, users) - session = Session(testing.db) + with fixture_session() as session: - with expect_warnings(".*during handling of a previous exception.*"): - session.begin_nested() - savepoint = session.connection()._nested_transaction._savepoint + with expect_warnings( + ".*during handling of a previous exception.*" + ): + session.begin_nested() + savepoint = session.connection()._nested_transaction._savepoint - # force the savepoint to disappear - session.connection().dialect.do_release_savepoint( - session.connection(), savepoint - ) + # force the savepoint to disappear + session.connection().dialect.do_release_savepoint( + session.connection(), savepoint + ) - # now do a broken flush - session.add_all([User(id=1), User(id=1)]) + # now do a broken flush + session.add_all([User(id=1), User(id=1)]) - assert_raises_message( - sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush - ) + assert_raises_message( + sa_exc.DBAPIError, "ROLLBACK TO SAVEPOINT ", session.flush + ) class _LocalFixture(FixtureTest): @@ -1170,39 +1172,40 @@ class SubtransactionRecipeTest(FixtureTest): def test_recipe_heavy_nesting(self, subtransaction_recipe): users = self.tables.users - session = Session(testing.db, future=self.future) - - with subtransaction_recipe(session): - session.connection().execute(users.insert().values(name="user1")) + with fixture_session(future=self.future) as session: with subtransaction_recipe(session): - savepoint = session.begin_nested() session.connection().execute( - users.insert().values(name="user2") + users.insert().values(name="user1") ) + with subtransaction_recipe(session): + savepoint = session.begin_nested() + session.connection().execute( + users.insert().values(name="user2") + ) + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 2 + ) + savepoint.rollback() + + with subtransaction_recipe(session): + assert ( + session.connection() + .exec_driver_sql("select count(1) from users") + .scalar() + == 1 + ) + session.connection().execute( + users.insert().values(name="user3") + ) assert ( session.connection() .exec_driver_sql("select count(1) from users") .scalar() == 2 ) - savepoint.rollback() - - with subtransaction_recipe(session): - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 1 - ) - session.connection().execute( - users.insert().values(name="user3") - ) - assert ( - session.connection() - .exec_driver_sql("select count(1) from users") - .scalar() - == 2 - ) @engines.close_open_connections def test_recipe_subtransaction_on_external_subtrans( @@ -1228,13 +1231,12 @@ class SubtransactionRecipeTest(FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session(testing.db, future=self.future) - - with subtransaction_recipe(sess): - u = User(name="u1") - sess.add(u) - sess.close() - assert len(sess.query(User).all()) == 1 + with fixture_session(future=self.future) as sess: + with subtransaction_recipe(sess): + u = User(name="u1") + sess.add(u) + sess.close() + assert len(sess.query(User).all()) == 1 def test_recipe_subtransaction_on_noautocommit( self, subtransaction_recipe @@ -1242,16 +1244,15 @@ class SubtransactionRecipeTest(FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - sess = Session(testing.db, future=self.future) - - sess.begin() - with subtransaction_recipe(sess): - u = User(name="u1") - sess.add(u) - sess.flush() - sess.rollback() # rolls back - assert len(sess.query(User).all()) == 0 - sess.close() + with fixture_session(future=self.future) as sess: + sess.begin() + with subtransaction_recipe(sess): + u = User(name="u1") + sess.add(u) + sess.flush() + sess.rollback() # rolls back + assert len(sess.query(User).all()) == 0 + sess.close() @testing.requires.savepoints def test_recipe_mixed_transaction_control(self, subtransaction_recipe): @@ -1259,30 +1260,28 @@ class SubtransactionRecipeTest(FixtureTest): mapper(User, users) - sess = Session(testing.db, future=self.future) + with fixture_session(future=self.future) as sess: - sess.begin() - sess.begin_nested() + sess.begin() + sess.begin_nested() - with subtransaction_recipe(sess): + with subtransaction_recipe(sess): - sess.add(User(name="u1")) + sess.add(User(name="u1")) - sess.commit() - sess.commit() + sess.commit() + sess.commit() - eq_(len(sess.query(User).all()), 1) - sess.close() + eq_(len(sess.query(User).all()), 1) + sess.close() - t1 = sess.begin() - t2 = sess.begin_nested() - - sess.add(User(name="u2")) + t1 = sess.begin() + t2 = sess.begin_nested() - t2.commit() - assert sess._legacy_transaction() is t1 + sess.add(User(name="u2")) - sess.close() + t2.commit() + assert sess._legacy_transaction() is t1 def test_recipe_error_on_using_inactive_session_commands( self, subtransaction_recipe @@ -1290,56 +1289,55 @@ class SubtransactionRecipeTest(FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - sess = Session(testing.db, future=self.future) - sess.begin() - - try: - with subtransaction_recipe(sess): - sess.add(User(name="u1")) - sess.flush() - raise Exception("force rollback") - except: - pass - - if self.recipe_rollsback_early: - # that was a real rollback, so no transaction - assert not sess.in_transaction() - is_(sess.get_transaction(), None) - else: - assert sess.in_transaction() - - sess.close() - assert not sess.in_transaction() - - def test_recipe_multi_nesting(self, subtransaction_recipe): - sess = Session(testing.db, future=self.future) - - with subtransaction_recipe(sess): - assert sess.in_transaction() + with fixture_session(future=self.future) as sess: + sess.begin() try: with subtransaction_recipe(sess): - assert sess._legacy_transaction() + sess.add(User(name="u1")) + sess.flush() raise Exception("force rollback") except: pass if self.recipe_rollsback_early: + # that was a real rollback, so no transaction assert not sess.in_transaction() + is_(sess.get_transaction(), None) else: assert sess.in_transaction() - assert not sess.in_transaction() + sess.close() + assert not sess.in_transaction() + + def test_recipe_multi_nesting(self, subtransaction_recipe): + with fixture_session(future=self.future) as sess: + with subtransaction_recipe(sess): + assert sess.in_transaction() + + try: + with subtransaction_recipe(sess): + assert sess._legacy_transaction() + raise Exception("force rollback") + except: + pass + + if self.recipe_rollsback_early: + assert not sess.in_transaction() + else: + assert sess.in_transaction() + + assert not sess.in_transaction() def test_recipe_deactive_status_check(self, subtransaction_recipe): - sess = Session(testing.db, future=self.future) - sess.begin() + with fixture_session(future=self.future) as sess: + sess.begin() - with subtransaction_recipe(sess): - sess.rollback() + with subtransaction_recipe(sess): + sess.rollback() - assert not sess.in_transaction() - sess.commit() # no error + assert not sess.in_transaction() + sess.commit() # no error class FixtureDataTest(_LocalFixture): @@ -1394,28 +1392,28 @@ class CleanSavepointTest(FixtureTest): mapper(User, users) - s = Session(bind=testing.db, future=future) - u1 = User(name="u1") - u2 = User(name="u2") - s.add_all([u1, u2]) - s.commit() - u1.name - u2.name - trans = s._transaction - assert trans is not None - s.begin_nested() - update_fn(s, u2) - eq_(u2.name, "u2modified") - s.rollback() + with fixture_session(future=future) as s: + u1 = User(name="u1") + u2 = User(name="u2") + s.add_all([u1, u2]) + s.commit() + u1.name + u2.name + trans = s._transaction + assert trans is not None + s.begin_nested() + update_fn(s, u2) + eq_(u2.name, "u2modified") + s.rollback() - if future: - assert s._transaction is None - assert "name" not in u1.__dict__ - else: - assert s._transaction is trans - eq_(u1.__dict__["name"], "u1") - assert "name" not in u2.__dict__ - eq_(u2.name, "u2") + if future: + assert s._transaction is None + assert "name" not in u1.__dict__ + else: + assert s._transaction is trans + eq_(u1.__dict__["name"], "u1") + assert "name" not in u2.__dict__ + eq_(u2.name, "u2") @testing.requires.savepoints def test_rollback_ignores_clean_on_savepoint(self): @@ -2116,82 +2114,108 @@ class ContextManagerPlusFutureTest(FixtureTest): eq_(sess.query(User).count(), 1) def test_explicit_begin(self): - s1 = Session(testing.db) - with s1.begin() as trans: - is_(trans, s1._legacy_transaction()) - s1.connection() + with fixture_session() as s1: + with s1.begin() as trans: + is_(trans, s1._legacy_transaction()) + s1.connection() - is_(s1._transaction, None) + is_(s1._transaction, None) def test_no_double_begin_explicit(self): - s1 = Session(testing.db) - s1.begin() - assert_raises_message( - sa_exc.InvalidRequestError, - "A transaction is already begun on this Session.", - s1.begin, - ) + with fixture_session() as s1: + s1.begin() + assert_raises_message( + sa_exc.InvalidRequestError, + "A transaction is already begun on this Session.", + s1.begin, + ) @testing.requires.savepoints def test_future_rollback_is_global(self): users = self.tables.users - s1 = Session(testing.db, future=True) + with fixture_session(future=True) as s1: + s1.begin() - s1.begin() + s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) - s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) + s1.begin_nested() - s1.begin_nested() - - s1.connection().execute( - users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}] - ) + s1.connection().execute( + users.insert(), + [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}], + ) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 3, + ) - # rolls back the whole transaction - s1.rollback() - is_(s1._legacy_transaction(), None) + # rolls back the whole transaction + s1.rollback() + is_(s1._legacy_transaction(), None) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 0) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 0, + ) - s1.commit() - is_(s1._legacy_transaction(), None) + s1.commit() + is_(s1._legacy_transaction(), None) @testing.requires.savepoints def test_old_rollback_is_local(self): users = self.tables.users - s1 = Session(testing.db) + with fixture_session() as s1: - t1 = s1.begin() + t1 = s1.begin() - s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) + s1.connection().execute(users.insert(), [{"id": 1, "name": "n1"}]) - s1.begin_nested() + s1.begin_nested() - s1.connection().execute( - users.insert(), [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}] - ) + s1.connection().execute( + users.insert(), + [{"id": 2, "name": "n2"}, {"id": 3, "name": "n3"}], + ) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 3) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 3, + ) - # rolls back only the savepoint - s1.rollback() + # rolls back only the savepoint + s1.rollback() - is_(s1._legacy_transaction(), t1) + is_(s1._legacy_transaction(), t1) - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1) + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 1, + ) - s1.commit() - eq_(s1.connection().scalar(select(func.count()).select_from(users)), 1) - is_not(s1._legacy_transaction(), None) + s1.commit() + eq_( + s1.connection().scalar( + select(func.count()).select_from(users) + ), + 1, + ) + is_not(s1._legacy_transaction(), None) def test_session_as_ctx_manager_one(self): users = self.tables.users - with Session(testing.db) as sess: + with fixture_session() as sess: is_not(sess._legacy_transaction(), None) sess.connection().execute( @@ -2212,7 +2236,7 @@ class ContextManagerPlusFutureTest(FixtureTest): def test_session_as_ctx_manager_future_one(self): users = self.tables.users - with Session(testing.db, future=True) as sess: + with fixture_session(future=True) as sess: is_(sess._legacy_transaction(), None) sess.connection().execute( @@ -2234,7 +2258,7 @@ class ContextManagerPlusFutureTest(FixtureTest): users = self.tables.users try: - with Session(testing.db) as sess: + with fixture_session() as sess: is_not(sess._legacy_transaction(), None) sess.connection().execute( @@ -2250,7 +2274,7 @@ class ContextManagerPlusFutureTest(FixtureTest): users = self.tables.users try: - with Session(testing.db, future=True) as sess: + with fixture_session(future=True) as sess: is_(sess._legacy_transaction(), None) sess.connection().execute( @@ -2265,7 +2289,7 @@ class ContextManagerPlusFutureTest(FixtureTest): def test_begin_context_manager(self): users = self.tables.users - with Session(testing.db) as sess: + with fixture_session() as sess: with sess.begin(): sess.connection().execute( users.insert().values(id=1, name="user1") @@ -2296,12 +2320,13 @@ class ContextManagerPlusFutureTest(FixtureTest): # committed eq_(sess.connection().execute(users.select()).all(), [(1, "user1")]) + sess.close() def test_begin_context_manager_rollback_trans(self): users = self.tables.users try: - with Session(testing.db) as sess: + with fixture_session() as sess: with sess.begin(): sess.connection().execute( users.insert().values(id=1, name="user1") @@ -2318,12 +2343,13 @@ class ContextManagerPlusFutureTest(FixtureTest): # rolled back eq_(sess.connection().execute(users.select()).all(), []) + sess.close() def test_begin_context_manager_rollback_outer(self): users = self.tables.users try: - with Session(testing.db) as sess: + with fixture_session() as sess: with sess.begin(): sess.connection().execute( users.insert().values(id=1, name="user1") @@ -2340,6 +2366,7 @@ class ContextManagerPlusFutureTest(FixtureTest): # committed eq_(sess.connection().execute(users.select()).all(), [(1, "user1")]) + sess.close() def test_sessionmaker_begin_context_manager_rollback_trans(self): users = self.tables.users @@ -2363,6 +2390,7 @@ class ContextManagerPlusFutureTest(FixtureTest): # rolled back eq_(sess.connection().execute(users.select()).all(), []) + sess.close() def test_sessionmaker_begin_context_manager_rollback_outer(self): users = self.tables.users @@ -2386,36 +2414,37 @@ class ContextManagerPlusFutureTest(FixtureTest): # committed eq_(sess.connection().execute(users.select()).all(), [(1, "user1")]) + sess.close() class TransactionFlagsTest(fixtures.TestBase): def test_in_transaction(self): - s1 = Session(testing.db) + with fixture_session() as s1: - eq_(s1.in_transaction(), False) + eq_(s1.in_transaction(), False) - trans = s1.begin() + trans = s1.begin() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) - n1 = s1.begin_nested() + n1 = s1.begin_nested() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), n1) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), n1) - n1.rollback() + n1.rollback() - is_(s1.get_nested_transaction(), None) - is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), None) + is_(s1.get_transaction(), trans) - eq_(s1.in_transaction(), True) + eq_(s1.in_transaction(), True) - s1.commit() + s1.commit() - eq_(s1.in_transaction(), False) - is_(s1.get_transaction(), None) + eq_(s1.in_transaction(), False) + is_(s1.get_transaction(), None) def test_in_transaction_subtransactions(self): """we'd like to do away with subtransactions for future sessions @@ -2425,72 +2454,71 @@ class TransactionFlagsTest(fixtures.TestBase): the external API works. """ - s1 = Session(testing.db) - - eq_(s1.in_transaction(), False) + with fixture_session() as s1: + eq_(s1.in_transaction(), False) - trans = s1.begin() + trans = s1.begin() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) - subtrans = s1.begin(_subtrans=True) - is_(s1.get_transaction(), trans) - eq_(s1.in_transaction(), True) + subtrans = s1.begin(_subtrans=True) + is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) - is_(s1._transaction, subtrans) + is_(s1._transaction, subtrans) - s1.rollback() + s1.rollback() - eq_(s1.in_transaction(), True) - is_(s1._transaction, trans) + eq_(s1.in_transaction(), True) + is_(s1._transaction, trans) - s1.rollback() + s1.rollback() - eq_(s1.in_transaction(), False) - is_(s1._transaction, None) + eq_(s1.in_transaction(), False) + is_(s1._transaction, None) def test_in_transaction_nesting(self): - s1 = Session(testing.db) + with fixture_session() as s1: - eq_(s1.in_transaction(), False) + eq_(s1.in_transaction(), False) - trans = s1.begin() + trans = s1.begin() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) - sp1 = s1.begin_nested() + sp1 = s1.begin_nested() - eq_(s1.in_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), sp1) + eq_(s1.in_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), sp1) - sp2 = s1.begin_nested() + sp2 = s1.begin_nested() - eq_(s1.in_transaction(), True) - eq_(s1.in_nested_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), sp2) + eq_(s1.in_transaction(), True) + eq_(s1.in_nested_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), sp2) - sp2.rollback() + sp2.rollback() - eq_(s1.in_transaction(), True) - eq_(s1.in_nested_transaction(), True) - is_(s1.get_transaction(), trans) - is_(s1.get_nested_transaction(), sp1) + eq_(s1.in_transaction(), True) + eq_(s1.in_nested_transaction(), True) + is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), sp1) - sp1.rollback() + sp1.rollback() - is_(s1.get_nested_transaction(), None) - eq_(s1.in_transaction(), True) - eq_(s1.in_nested_transaction(), False) - is_(s1.get_transaction(), trans) + is_(s1.get_nested_transaction(), None) + eq_(s1.in_transaction(), True) + eq_(s1.in_nested_transaction(), False) + is_(s1.get_transaction(), trans) - s1.rollback() + s1.rollback() - eq_(s1.in_transaction(), False) - is_(s1.get_transaction(), None) + eq_(s1.in_transaction(), False) + is_(s1.get_transaction(), None) class NaturalPKRollbackTest(fixtures.MappedTest): @@ -2674,8 +2702,11 @@ class NaturalPKRollbackTest(fixtures.MappedTest): class JoinIntoAnExternalTransactionFixture(object): """Test the "join into an external transaction" examples""" - def setup(self): - self.connection = testing.db.connect() + __leave_connections_for_teardown__ = True + + def setup_test(self): + self.engine = testing.db + self.connection = self.engine.connect() self.metadata = MetaData() self.table = Table( @@ -2686,6 +2717,17 @@ class JoinIntoAnExternalTransactionFixture(object): self.setup_session() + def teardown_test(self): + self.teardown_session() + + with self.connection.begin(): + self._assert_count(0) + + with self.connection.begin(): + self.table.drop(self.connection) + + self.connection.close() + def test_something(self): A = self.A @@ -2727,18 +2769,6 @@ class JoinIntoAnExternalTransactionFixture(object): ) eq_(result, count) - def teardown(self): - self.teardown_session() - - with self.connection.begin(): - self._assert_count(0) - - with self.connection.begin(): - self.table.drop(self.connection) - - # return connection to the Engine - self.connection.close() - class NewStyleJoinIntoAnExternalTransactionTest( JoinIntoAnExternalTransactionFixture @@ -2775,7 +2805,8 @@ class NewStyleJoinIntoAnExternalTransactionTest( # rollback - everything that happened with the # Session above (including calls to commit()) # is rolled back. - self.trans.rollback() + if self.trans.is_active: + self.trans.rollback() class FutureJoinIntoAnExternalTransactionTest( diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 84373b2dca..2c35bec45f 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -203,14 +203,6 @@ class UnicodeSchemaTest(fixtures.MappedTest): cls.tables["t1"] = t1 cls.tables["t2"] = t2 - @classmethod - def setup_class(cls): - super(UnicodeSchemaTest, cls).setup_class() - - @classmethod - def teardown_class(cls): - super(UnicodeSchemaTest, cls).teardown_class() - def test_mapping(self): t2, t1 = self.tables.t2, self.tables.t1 diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 4e713627c1..65089f773c 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -771,14 +771,13 @@ class RudimentaryFlushTest(UOWTest): class SingleCycleTest(UOWTest): - def teardown(self): + def teardown_test(self): engines.testing_reaper.rollback_all() # mysql can't handle delete from nodes # since it doesn't deal with the FKs correctly, # so wipe out the parent_id first with testing.db.begin() as conn: conn.execute(self.tables.nodes.update().values(parent_id=None)) - super(SingleCycleTest, self).teardown() def test_one_to_many_save(self): Node, nodes = self.classes.Node, self.tables.nodes diff --git a/test/requirements.py b/test/requirements.py index d5a7183727..3c9b39ac71 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1624,15 +1624,15 @@ class DefaultRequirements(SuiteRequirements): @property def postgresql_utf8_server_encoding(self): + def go(config): + if not against(config, "postgresql"): + return False - return only_if( - lambda config: against(config, "postgresql") - and config.db.connect(close_with_result=True) - .exec_driver_sql("show server_encoding") - .scalar() - .lower() - == "utf8" - ) + with config.db.connect() as conn: + enc = conn.exec_driver_sql("show server_encoding").scalar() + return enc.lower() == "utf8" + + return only_if(go) @property def cxoracle6_or_greater(self): diff --git a/test/sql/test_case_statement.py b/test/sql/test_case_statement.py index 4bef1df7f3..b44971cecd 100644 --- a/test/sql/test_case_statement.py +++ b/test/sql/test_case_statement.py @@ -26,7 +26,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): metadata = MetaData() global info_table info_table = Table( @@ -52,7 +52,7 @@ class CaseTest(fixtures.TestBase, AssertsCompiledSQL): ) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): with testing.db.begin() as conn: info_table.drop(conn) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 70281d4e89..1ac3613f74 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1203,7 +1203,7 @@ class CacheKeyTest(CacheKeyFixture, CoreFixtures, fixtures.TestBase): class CompareAndCopyTest(CoreFixtures, fixtures.TestBase): @classmethod - def setup_class(cls): + def setup_test_class(cls): # TODO: we need to get dialects here somehow, perhaps in test_suite? [ importlib.import_module("sqlalchemy.dialects.%s" % d) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index fdffe04bf3..4429753ecb 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -4306,7 +4306,7 @@ class StringifySpecialTest(fixtures.TestBase): class KwargPropagationTest(fixtures.TestBase): @classmethod - def setup_class(cls): + def setup_test_class(cls): from sqlalchemy.sql.expression import ColumnClause, TableClause class CatchCol(ColumnClause): diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 2a2e70bc39..8be7eed1f5 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -503,9 +503,8 @@ class DefaultRoundTripTest(fixtures.TablesTest): Column("col11", MyType(), default="foo"), ) - def teardown(self): + def teardown_test(self): self.default_generator["x"] = 50 - super(DefaultRoundTripTest, self).teardown() def test_standalone(self, connection): t = self.tables.default_test @@ -1226,7 +1225,7 @@ class SpecialTypePKTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): class MyInteger(TypeDecorator): impl = Integer diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index acc12a5feb..7775652208 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -2488,12 +2488,12 @@ class LegacySequenceExecTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): cls.seq = Sequence("my_sequence") cls.seq.create(testing.db) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): cls.seq.drop(testing.db) def _assert_seq_result(self, ret): @@ -2574,7 +2574,7 @@ class LegacySequenceExecTest(fixtures.TestBase): class DDLDeprecatedBindTest(fixtures.TestBase): - def teardown(self): + def teardown_test(self): with testing.db.begin() as conn: if inspect(conn).has_table("foo"): conn.execute(schema.DropTable(table("foo"))) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 4edc9d0258..a6001ba9da 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -47,7 +47,7 @@ class TraversalTest(fixtures.TestBase, AssertsExecutionResults): ability to copy and modify a ClauseElement in place.""" @classmethod - def setup_class(cls): + def setup_test_class(cls): global A, B # establish two fictitious ClauseElements. @@ -321,7 +321,7 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2, t3 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) @@ -1012,7 +1012,7 @@ class ColumnAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table( "table1", @@ -1196,7 +1196,7 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) @@ -1943,7 +1943,7 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global table1, table2, table3, table4 def _table(name): @@ -2031,7 +2031,7 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) @@ -2128,7 +2128,7 @@ class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): # fixme: consolidate converage from elsewhere here and expand @classmethod - def setup_class(cls): + def setup_test_class(cls): global t1, t2 t1 = table("table1", column("col1"), column("col2"), column("col3")) t2 = table("table2", column("col1"), column("col2"), column("col3")) diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py index 6afe41aaca..b0bcee18e2 100644 --- a/test/sql/test_from_linter.py +++ b/test/sql/test_from_linter.py @@ -25,7 +25,7 @@ class TestFindUnmatchingFroms(fixtures.TablesTest): Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) - def setup(self): + def setup_test(self): self.a = self.tables.table_a self.b = self.tables.table_b self.c = self.tables.table_c @@ -267,8 +267,10 @@ class TestLinter(fixtures.TablesTest): with self.bind.connect() as conn: conn.execute(query) - def test_no_linting(self): - eng = engines.testing_engine(options={"enable_from_linting": False}) + def test_no_linting(self, metadata, connection): + eng = engines.testing_engine( + options={"enable_from_linting": False, "use_reaper": False} + ) eng.pool = self.bind.pool # needed for SQLite a, b = self.tables("table_a", "table_b") query = select(a.c.col_a).where(b.c.col_b == 5) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 91076f9c38..32ea642d74 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -54,10 +54,10 @@ table1 = table( class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def setup(self): + def setup_test(self): self._registry = deepcopy(functions._registry) - def teardown(self): + def teardown_test(self): functions._registry = self._registry def test_compile(self): @@ -938,7 +938,7 @@ class ReturnTypeTest(AssertsCompiledSQL, fixtures.TestBase): class ExecuteTest(fixtures.TestBase): __backend__ = True - def tearDown(self): + def teardown_test(self): pass def test_conn_execute(self, connection): @@ -1113,10 +1113,10 @@ class ExecuteTest(fixtures.TestBase): class RegisterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def setup(self): + def setup_test(self): self._registry = deepcopy(functions._registry) - def teardown(self): + def teardown_test(self): functions._registry = self._registry def test_GenericFunction_is_registered(self): diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 502f70ce78..9bf351f5c4 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -4257,7 +4257,7 @@ class DialectKWArgTest(fixtures.TestBase): with mock.patch("sqlalchemy.dialects.registry.load", load): yield - def teardown(self): + def teardown_test(self): Index._kw_registry.clear() def test_participating(self): diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index aaeed68ddb..270e79ba16 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -608,7 +608,7 @@ class ExtensionOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): - def setUp(self): + def setup_test(self): class MyTypeCompiler(compiler.GenericTypeCompiler): def visit_mytype(self, type_, **kw): return "MYTYPE" @@ -766,7 +766,7 @@ class JSONIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): - def setUp(self): + def setup_test(self): class MyTypeCompiler(compiler.GenericTypeCompiler): def visit_mytype(self, type_, **kw): return "MYTYPE" @@ -2370,7 +2370,7 @@ class MatchTest(fixtures.TestBase, testing.AssertsCompiledSQL): class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) @@ -2403,7 +2403,7 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class RegexpTestStrCompiler(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default_enhanced" - def setUp(self): + def setup_test(self): self.table = table( "mytable", column("myid", Integer), column("name", String) ) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 136f10cf4c..7ad12c6203 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -661,8 +661,8 @@ class CursorResultTest(fixtures.TablesTest): assert_raises(KeyError, lambda: row._mapping["Case_insensitive"]) assert_raises(KeyError, lambda: row._mapping["casesensitive"]) - def test_row_case_sensitive_unoptimized(self): - with engines.testing_engine().connect() as ins_conn: + def test_row_case_sensitive_unoptimized(self, testing_engine): + with testing_engine().connect() as ins_conn: row = ins_conn.execute( select( literal_column("1").label("case_insensitive"), @@ -1234,8 +1234,7 @@ class CursorResultTest(fixtures.TablesTest): eq_(proxy[0], "value") eq_(proxy._mapping["key"], "value") - @testing.provide_metadata - def test_no_rowcount_on_selects_inserts(self): + def test_no_rowcount_on_selects_inserts(self, metadata, testing_engine): """assert that rowcount is only called on deletes and updates. This because cursor.rowcount may can be expensive on some dialects @@ -1244,9 +1243,7 @@ class CursorResultTest(fixtures.TablesTest): """ - metadata = self.metadata - - engine = engines.testing_engine() + engine = testing_engine() t = Table("t1", metadata, Column("data", String(10))) metadata.create_all(engine) @@ -2132,7 +2129,9 @@ class AlternateCursorResultTest(fixtures.TablesTest): @classmethod def setup_bind(cls): - cls.engine = engine = engines.testing_engine("sqlite://") + cls.engine = engine = engines.testing_engine( + "sqlite://", options={"scope": "class"} + ) return engine @classmethod diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index 65325aa6f7..5cfc2663f6 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -100,12 +100,12 @@ class SequenceExecTest(fixtures.TestBase): __backend__ = True @classmethod - def setup_class(cls): + def setup_test_class(cls): cls.seq = Sequence("my_sequence") cls.seq.create(testing.db) @classmethod - def teardown_class(cls): + def teardown_test_class(cls): cls.seq.drop(testing.db) def _assert_seq_result(self, ret): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 0e11478007..64ace87dfa 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1375,7 +1375,7 @@ class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL): class VariantTest(fixtures.TestBase, AssertsCompiledSQL): - def setup(self): + def setup_test(self): class UTypeOne(types.UserDefinedType): def get_col_spec(self): return "UTYPEONE" @@ -2504,7 +2504,7 @@ class BinaryTest(fixtures.TablesTest, AssertsExecutionResults): class JSONTest(fixtures.TestBase): - def setup(self): + def setup_test(self): metadata = MetaData() self.test_table = Table( "test_table", @@ -3445,7 +3445,12 @@ class BooleanTest( @testing.requires.non_native_boolean_unconstrained def test_constraint(self, connection): assert_raises( - (exc.IntegrityError, exc.ProgrammingError, exc.OperationalError), + ( + exc.IntegrityError, + exc.ProgrammingError, + exc.OperationalError, + exc.InternalError, # older pymysql's do this + ), connection.exec_driver_sql, "insert into boolean_table (id, value) values(1, 5)", ) diff --git a/tox.ini b/tox.ini index 8f2fd9de0f..ea2b76e166 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,9 @@ install_command=python -m pip install {env:TOX_PIP_OPTS:} {opts} {packages} usedevelop= cov: True -deps=pytest>=4.6.11 # this can be 6.x once we are on python 3 only +deps= + pytest>=4.6.11,<5.0; python_version < '3' + pytest>=6.2; python_version >= '3' pytest-xdist greenlet != 0.4.17 mock; python_version < '3.3' @@ -74,9 +76,11 @@ setenv= sqlite_file: SQLITE={env:TOX_SQLITE_FILE:--db sqlite_file} postgresql: POSTGRESQL={env:TOX_POSTGRESQL:--db postgresql} + py2{,7}-postgresql: POSTGRESQL={env:TOX_POSTGRESQL_PY2K:{env:TOX_POSTGRESQL:--db postgresql}} py3{,5,6,7,8,9,10,11}-postgresql: EXTRA_PG_DRIVERS={env:EXTRA_PG_DRIVERS:--dbdriver psycopg2 --dbdriver asyncpg --dbdriver pg8000} mysql: MYSQL={env:TOX_MYSQL:--db mysql} + py2{,7}-mysql: MYSQL={env:TOX_MYSQL_PY2K:{env:TOX_MYSQL:--db mysql}} mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql} py3{,5,6,7,8,9,10,11}-mysql: EXTRA_MYSQL_DRIVERS={env:EXTRA_MYSQL_DRIVERS:--dbdriver mysqldb --dbdriver pymysql --dbdriver mariadbconnector --dbdriver aiomysql} @@ -89,7 +93,7 @@ setenv= # tox as of 2.0 blocks all environment variables from the # outside, unless they are here (or in TOX_TESTENV_PASSENV, # wildcards OK). Need at least these -passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_MYSQL TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS +passenv=ORACLE_HOME NLS_LANG TOX_POSTGRESQL TOX_POSTGRESQL_PY2K TOX_MYSQL TOX_MYSQL_PY2K TOX_ORACLE TOX_MSSQL TOX_SQLITE TOX_SQLITE_FILE TOX_WORKERS EXTRA_PG_DRIVERS EXTRA_MYSQL_DRIVERS # for nocext, we rm *.so in lib in case we are doing usedevelop=True commands= -- 2.39.5