From: Mike Bayer Date: Sun, 15 Nov 2020 21:58:50 +0000 (-0500) Subject: correct for "autocommit" deprecation warning X-Git-Tag: rel_1_4_0b2~106^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=ba5cbf9366e9b2c5ed8e27e91815d7a2c3b63e41;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git correct for "autocommit" deprecation warning Ensure no autocommit warnings occur internally or within tests. Also includes fixes for SQL Server full text tests which apparently have not been working at all for a long time, as it used long removed APIs. CI has not had fulltext running for some years and is now installed. Change-Id: Id806e1856c9da9f0a9eac88cebc7a94ecc95eb96 --- diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py index c1d83bbb76..50b6e3c850 100644 --- a/lib/sqlalchemy/dialects/mysql/provision.py +++ b/lib/sqlalchemy/dialects/mysql/provision.py @@ -41,12 +41,13 @@ def generate_driver_url(url, driver, query_str): @create_db.for_db("mysql", "mariadb") def _mysql_create_db(cfg, eng, ident): - with eng.connect() as conn: + with eng.begin() as conn: try: _mysql_drop_db(cfg, conn, ident) except Exception: pass + with eng.begin() as conn: conn.exec_driver_sql( "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident ) @@ -66,7 +67,7 @@ def _mysql_configure_follower(config, ident): @drop_db.for_db("mysql", "mariadb") def _mysql_drop_db(cfg, eng, ident): - with eng.connect() as conn: + with eng.begin() as conn: conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident) conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident) conn.exec_driver_sql("DROP DATABASE %s" % ident) diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index d19dfc9fe6..aadc2c5a99 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -17,7 +17,7 @@ def _oracle_create_db(cfg, eng, ident): # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or # similar, so that the default tablespace is not "system"; reflection will # fail otherwise - with eng.connect() as conn: + with eng.begin() as conn: conn.exec_driver_sql("create user %s identified by xe" % ident) conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident) conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident) @@ -45,7 +45,7 @@ def _ora_drop_ignore(conn, dbname): @drop_db.for_db("oracle") def _oracle_drop_db(cfg, eng, ident): - with eng.connect() as conn: + with eng.begin() as conn: # cx_Oracle seems to occasionally leak open connections when a large # suite it run, even if we confirm we have zero references to # connection objects. @@ -65,7 +65,7 @@ def _oracle_update_db_opts(db_url, db_opts): def _reap_oracle_dbs(url, idents): log.info("db reaper connecting to %r", url) eng = create_engine(url) - with eng.connect() as conn: + with eng.begin() as conn: log.info("identifiers in file: %s", ", ".join(idents)) diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py index 9433ec4585..575316c61d 100644 --- a/lib/sqlalchemy/dialects/postgresql/provision.py +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -13,7 +13,7 @@ from ...testing.provision import temp_table_keyword_args def _pg_create_db(cfg, eng, ident): template_db = cfg.options.postgresql_templatedb - with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn: try: _pg_drop_db(cfg, conn, ident) except Exception: @@ -51,15 +51,16 @@ def _pg_create_db(cfg, eng, ident): @drop_db.for_db("postgresql") def _pg_drop_db(cfg, eng, ident): with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: - conn.execute( - text( - "select pg_terminate_backend(pid) from pg_stat_activity " - "where usename=current_user and pid != pg_backend_pid() " - "and datname=:dname" - ), - dname=ident, - ) - conn.exec_driver_sql("DROP DATABASE %s" % ident) + with conn.begin(): + conn.execute( + text( + "select pg_terminate_backend(pid) from pg_stat_activity " + "where usename=current_user and pid != pg_backend_pid() " + "and datname=:dname" + ), + dname=ident, + ) + conn.exec_driver_sql("DROP DATABASE %s" % ident) @temp_table_keyword_args.for_db("postgresql") diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 9a5518a961..028af9fbb7 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -840,7 +840,15 @@ class Connection(Connectable): def _commit_impl(self, autocommit=False): assert not self.__branch_from - if autocommit: + # AUTOCOMMIT isolation-level is a dialect-specific concept, however + # if a connection has this set as the isolation level, we can skip + # the "autocommit" warning as the operation will do "autocommit" + # in any case + if ( + autocommit + and self._execution_options.get("isolation_level", None) + != "AUTOCOMMIT" + ): util.warn_deprecated_20( "The current statement is being autocommitted using " "implicit autocommit, which will be removed in " @@ -2687,9 +2695,11 @@ class Engine(Connectable, log.Identified): self.pool = self.pool.recreate() self.dispatch.engine_disposed(self) - def _execute_default(self, default): + def _execute_default( + self, default, multiparams=(), params=util.EMPTY_DICT + ): with self.connect() as conn: - return conn._execute_default(default, (), {}) + return conn._execute_default(default, multiparams, params) @contextlib.contextmanager def _optional_conn_ctx_manager(self, connection=None): diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4b19ff02a1..b5e45c18d9 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2258,10 +2258,10 @@ class DefaultGenerator(SchemaItem): "or in the ORM by the :meth:`.Session.execute` method of " ":class:`.Session`.", ) - def execute(self, bind=None, **kwargs): + def execute(self, bind=None): if bind is None: bind = _bind_or_error(self) - return bind.execute(self, **kwargs) + return bind._execute_default(self, (), util.EMPTY_DICT) def _execute_on_connection( self, connection, multiparams, params, execution_options diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py index 4addca009b..a94ee55dc0 100644 --- a/lib/sqlalchemy/testing/suite/test_cte.py +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -1,4 +1,3 @@ -from .. import config from .. import fixtures from ..assertions import eq_ from ..schema import Column @@ -48,164 +47,158 @@ class CTETest(fixtures.TablesTest): ], ) - def test_select_nonrecursive_round_trip(self): + def test_select_nonrecursive_round_trip(self, connection): some_table = self.tables.some_table - with config.db.connect() as conn: - cte = ( - select(some_table) - .where(some_table.c.data.in_(["d2", "d3", "d4"])) - .cte("some_cte") - ) - result = conn.execute( - select(cte.c.data).where(cte.c.data.in_(["d4", "d5"])) - ) - eq_(result.fetchall(), [("d4",)]) + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + result = connection.execute( + select(cte.c.data).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4",)]) - def test_select_recursive_round_trip(self): + def test_select_recursive_round_trip(self, connection): some_table = self.tables.some_table - with config.db.connect() as conn: - cte = ( - select(some_table) - .where(some_table.c.data.in_(["d2", "d3", "d4"])) - .cte("some_cte", recursive=True) - ) + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) - cte_alias = cte.alias("c1") - st1 = some_table.alias() - # note that SQL Server requires this to be UNION ALL, - # can't be UNION - cte = cte.union_all( - select(st1).where(st1.c.id == cte_alias.c.parent_id) - ) - result = conn.execute( - select(cte.c.data) - .where(cte.c.data != "d2") - .order_by(cte.c.data.desc()) - ) - eq_( - result.fetchall(), - [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], - ) + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select(st1).where(st1.c.id == cte_alias.c.parent_id) + ) + result = connection.execute( + select(cte.c.data) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], + ) - def test_insert_from_select_round_trip(self): + def test_insert_from_select_round_trip(self, connection): some_table = self.tables.some_table some_other_table = self.tables.some_other_table - with config.db.connect() as conn: - cte = ( - select(some_table) - .where(some_table.c.data.in_(["d2", "d3", "d4"])) - .cte("some_cte") - ) - conn.execute( - some_other_table.insert().from_select( - ["id", "data", "parent_id"], select(cte) - ) - ) - eq_( - conn.execute( - select(some_other_table).order_by(some_other_table.c.id) - ).fetchall(), - [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(cte) ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], + ) @testing.requires.ctes_with_update_delete @testing.requires.update_from - def test_update_from_round_trip(self): + def test_update_from_round_trip(self, connection): some_table = self.tables.some_table some_other_table = self.tables.some_other_table - with config.db.connect() as conn: - conn.execute( - some_other_table.insert().from_select( - ["id", "data", "parent_id"], select(some_table) - ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) ) + ) - cte = ( - select(some_table) - .where(some_table.c.data.in_(["d2", "d3", "d4"])) - .cte("some_cte") - ) - conn.execute( - some_other_table.update() - .values(parent_id=5) - .where(some_other_table.c.data == cte.c.data) - ) - eq_( - conn.execute( - select(some_other_table).order_by(some_other_table.c.id) - ).fetchall(), - [ - (1, "d1", None), - (2, "d2", 5), - (3, "d3", 5), - (4, "d4", 5), - (5, "d5", 3), - ], - ) + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], + ) @testing.requires.ctes_with_update_delete @testing.requires.delete_from - def test_delete_from_round_trip(self): + def test_delete_from_round_trip(self, connection): some_table = self.tables.some_table some_other_table = self.tables.some_other_table - with config.db.connect() as conn: - conn.execute( - some_other_table.insert().from_select( - ["id", "data", "parent_id"], select(some_table) - ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) ) + ) - cte = ( - select(some_table) - .where(some_table.c.data.in_(["d2", "d3", "d4"])) - .cte("some_cte") - ) - conn.execute( - some_other_table.delete().where( - some_other_table.c.data == cte.c.data - ) - ) - eq_( - conn.execute( - select(some_other_table).order_by(some_other_table.c.id) - ).fetchall(), - [(1, "d1", None), (5, "d5", 3)], + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) @testing.requires.ctes_with_update_delete - def test_delete_scalar_subq_round_trip(self): + def test_delete_scalar_subq_round_trip(self, connection): some_table = self.tables.some_table some_other_table = self.tables.some_other_table - with config.db.connect() as conn: - conn.execute( - some_other_table.insert().from_select( - ["id", "data", "parent_id"], select(some_table) - ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) ) + ) - cte = ( - select(some_table) - .where(some_table.c.data.in_(["d2", "d3", "d4"])) - .cte("some_cte") - ) - conn.execute( - some_other_table.delete().where( - some_other_table.c.data - == select(cte.c.data) - .where(cte.c.id == some_other_table.c.id) - .scalar_subquery() - ) - ) - eq_( - conn.execute( - select(some_other_table).order_by(some_other_table.c.id) - ).fetchall(), - [(1, "d1", None), (5, "d5", 3)], + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data + == select(cte.c.data) + .where(cte.c.id == some_other_table.c.id) + .scalar_subquery() ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py index 7f697b915d..b0df1218dd 100644 --- a/lib/sqlalchemy/testing/suite/test_dialect.py +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -123,7 +123,7 @@ class IsolationLevelTest(fixtures.TestBase): eq_(conn.get_isolation_level(), existing) -class AutocommitTest(fixtures.TablesTest): +class AutocommitIsolationTest(fixtures.TablesTest): run_deletes = "each" @@ -153,7 +153,8 @@ class AutocommitTest(fixtures.TablesTest): 1 if autocommit else None, ) - conn.execute(self.tables.some_table.delete()) + with conn.begin(): + conn.execute(self.tables.some_table.delete()) def test_autocommit_on(self): conn = config.db.connect() @@ -170,7 +171,7 @@ class AutocommitTest(fixtures.TablesTest): def test_turn_autocommit_off_via_default_iso_level(self): conn = config.db.connect() - conn.execution_options(isolation_level="AUTOCOMMIT") + conn = conn.execution_options(isolation_level="AUTOCOMMIT") self._test_conn_autocommits(conn, True) conn.execution_options( diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py index 9484d41d09..0298738663 100644 --- a/lib/sqlalchemy/testing/suite/test_results.py +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -355,7 +355,7 @@ class ServerSideCursorsTest( Column("data", String(50)), ) - with engine.connect() as connection: + with engine.begin() as connection: test_table.create(connection, checkfirst=True) connection.execute(test_table.insert(), dict(data="data1")) connection.execute(test_table.insert(), dict(data="data2")) @@ -396,7 +396,7 @@ class ServerSideCursorsTest( Column("data", String(50)), ) - with engine.connect() as connection: + with engine.begin() as connection: test_table.create(connection, checkfirst=True) connection.execute( test_table.insert(), diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py index 06945ff2a7..f3f902abd2 100644 --- a/lib/sqlalchemy/testing/suite/test_rowcount.py +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -58,12 +58,14 @@ class RowCountTest(fixtures.TablesTest): assert len(r) == len(self.data) - def test_update_rowcount1(self): + def test_update_rowcount1(self, connection): employees_table = self.tables.employees # WHERE matches 3, 3 rows changed department = employees_table.c.department - r = employees_table.update(department == "C").execute(department="Z") + r = connection.execute( + employees_table.update(department == "C"), {"department": "Z"} + ) assert r.rowcount == 3 def test_update_rowcount2(self, connection): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index da01aa484b..21d2e8942d 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -340,7 +340,7 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): # passing NULL for an expression that needs to be interpreted as # a certain type, does the DBAPI have the info it needs to do this. date_table = self.tables.date_table - with config.db.connect() as conn: + with config.db.begin() as conn: result = conn.execute( date_table.insert(), {"date_data": self.data} ) @@ -702,7 +702,7 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): # testing "WHERE " renders a compatible expression boolean_table = self.tables.boolean_table - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( boolean_table.insert(), [ @@ -817,7 +817,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_index_typed_access(self, datatype, value): data_table = self.tables.data_table data_element = {"key1": value} - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( data_table.insert(), { @@ -841,7 +841,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_index_typed_comparison(self, datatype, value): data_table = self.tables.data_table data_element = {"key1": value} - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( data_table.insert(), { @@ -864,7 +864,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_path_typed_comparison(self, datatype, value): data_table = self.tables.data_table data_element = {"key1": {"subkey1": value}} - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( data_table.insert(), { @@ -900,7 +900,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_single_element_round_trip(self, element): data_table = self.tables.data_table data_element = element - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( data_table.insert(), { @@ -928,7 +928,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): # support sqlite :memory: database... data_table.create(engine, checkfirst=True) - with engine.connect() as conn: + with engine.begin() as conn: conn.execute( data_table.insert(), {"name": "row1", "data": data_element} ) @@ -978,7 +978,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_round_trip_none_as_json_null(self): col = self.tables.data_table.c["data"] - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( self.tables.data_table.insert(), {"name": "r1", "data": None} ) @@ -996,7 +996,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): def test_unicode_round_trip(self): # note we include Unicode supplementary characters as well - with config.db.connect() as conn: + with config.db.begin() as conn: conn.execute( self.tables.data_table.insert(), { diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index c52dc4a19b..c6626b9e08 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -370,7 +370,7 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None): if include_names is not None: include_names = set(include_names) - with engine.connect() as conn: + with engine.begin() as conn: for tname, fkcs in reversed( inspector.get_sorted_table_and_fkc_names(schema=schema) ): diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py index b230bad6f0..cd919cc0b0 100644 --- a/lib/sqlalchemy/testing/warnings.py +++ b/lib/sqlalchemy/testing/warnings.py @@ -51,13 +51,11 @@ def setup_filters(): # Core execution # r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method", - r"The current statement is being autocommitted using implicit " - "autocommit,", r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept " "parameters as a single dictionary or a single sequence of " "dictionaries only.", r"The Connection.connect\(\) method is considered legacy", - r".*DefaultGenerator.execute\(\)", + # r".*DefaultGenerator.execute\(\)", # # bound metadaa # diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 7188c41250..d36a0c9e1b 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -48,25 +48,28 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults): ) def setup(self): - metadata.create_all() - t.insert().execute( - [ - dict( - ("field%d" % fnum, u("value%d" % fnum)) - for fnum in range(NUM_FIELDS) - ) - for r_num in range(NUM_RECORDS) - ] - ) - t2.insert().execute( - [ - dict( - ("field%d" % fnum, u("value%d" % fnum)) - for fnum in range(NUM_FIELDS) - ) - for r_num in range(NUM_RECORDS) - ] - ) + with testing.db.begin() as conn: + metadata.create_all(conn) + conn.execute( + t.insert(), + [ + dict( + ("field%d" % fnum, u("value%d" % fnum)) + for fnum in range(NUM_FIELDS) + ) + for r_num in range(NUM_RECORDS) + ], + ) + conn.execute( + t2.insert(), + [ + dict( + ("field%d" % fnum, u("value%d" % fnum)) + for fnum in range(NUM_FIELDS) + ) + for r_num in range(NUM_RECORDS) + ], + ) # warm up type caches with testing.db.connect() as conn: diff --git a/test/conftest.py b/test/conftest.py index 63f3989ebc..0db4486a92 100755 --- a/test/conftest.py +++ b/test/conftest.py @@ -12,6 +12,8 @@ import sys import pytest +os.environ["SQLALCHEMY_WARN_20"] = "true" + collect_ignore_glob = [] # minimum version for a py3k only test is at diff --git a/test/dialect/mssql/test_engine.py b/test/dialect/mssql/test_engine.py index 4444559589..668df6ecbc 100644 --- a/test/dialect/mssql/test_engine.py +++ b/test/dialect/mssql/test_engine.py @@ -382,7 +382,7 @@ class FastExecutemanyTest(fixtures.TestBase): if executemany: assert cursor.fast_executemany - with eng.connect() as conn: + with eng.begin() as conn: conn.execute( t.insert(), [{"id": i, "data": "data_%d" % i} for i in range(100)], diff --git a/test/dialect/mssql/test_query.py b/test/dialect/mssql/test_query.py index d9dc033e16..ea0bfa4d27 100644 --- a/test/dialect/mssql/test_query.py +++ b/test/dialect/mssql/test_query.py @@ -9,7 +9,6 @@ from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import literal -from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import select @@ -26,22 +25,15 @@ from sqlalchemy.testing.assertsql import CursorSQL from sqlalchemy.testing.assertsql import DialectSQL from sqlalchemy.util import ue -metadata = None -cattable = None -matchtable = None - -class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL): +class IdentityInsertTest(fixtures.TablesTest, AssertsCompiledSQL): __only_on__ = "mssql" __dialect__ = mssql.MSDialect() __backend__ = True @classmethod - def setup_class(cls): - global metadata, cattable - metadata = MetaData(testing.db) - - cattable = Table( + def define_tables(cls, metadata): + Table( "cattable", metadata, Column("id", Integer), @@ -49,82 +41,82 @@ class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL): PrimaryKeyConstraint("id", name="PK_cattable"), ) - def setup(self): - metadata.create_all() - - def teardown(self): - metadata.drop_all() - def test_compiled(self): + cattable = self.tables.cattable self.assert_compile( cattable.insert().values(id=9, description="Python"), "INSERT INTO cattable (id, description) " "VALUES (:id, :description)", ) - def test_execute(self): - with testing.db.connect() as conn: - conn.execute(cattable.insert().values(id=9, description="Python")) - - cats = conn.execute(cattable.select().order_by(cattable.c.id)) - eq_([(9, "Python")], list(cats)) + def test_execute(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute(cattable.insert().values(id=9, description="Python")) - result = conn.execute(cattable.insert().values(description="PHP")) - eq_(result.inserted_primary_key, (10,)) - lastcat = conn.execute( - cattable.select().order_by(desc(cattable.c.id)) - ) - eq_((10, "PHP"), lastcat.first()) - - def test_executemany(self): - with testing.db.connect() as conn: - conn.execute( - cattable.insert(), - [ - {"id": 89, "description": "Python"}, - {"id": 8, "description": "Ruby"}, - {"id": 3, "description": "Perl"}, - {"id": 1, "description": "Java"}, - ], - ) - cats = conn.execute(cattable.select().order_by(cattable.c.id)) - eq_( - [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")], - list(cats), - ) - conn.execute( - cattable.insert(), - [{"description": "PHP"}, {"description": "Smalltalk"}], - ) - lastcats = conn.execute( - cattable.select().order_by(desc(cattable.c.id)).limit(2) - ) - eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats)) + cats = conn.execute(cattable.select().order_by(cattable.c.id)) + eq_([(9, "Python")], list(cats)) - def test_insert_plain_param(self): - with testing.db.connect() as conn: - conn.execute(cattable.insert(), id=5) - eq_(conn.scalar(select(cattable.c.id)), 5) + result = conn.execute(cattable.insert().values(description="PHP")) + eq_(result.inserted_primary_key, (10,)) + lastcat = conn.execute(cattable.select().order_by(desc(cattable.c.id))) + eq_((10, "PHP"), lastcat.first()) - def test_insert_values_key_plain(self): - with testing.db.connect() as conn: - conn.execute(cattable.insert().values(id=5)) - eq_(conn.scalar(select(cattable.c.id)), 5) - - def test_insert_values_key_expression(self): - with testing.db.connect() as conn: - conn.execute(cattable.insert().values(id=literal(5))) - eq_(conn.scalar(select(cattable.c.id)), 5) - - def test_insert_values_col_plain(self): - with testing.db.connect() as conn: - conn.execute(cattable.insert().values({cattable.c.id: 5})) - eq_(conn.scalar(select(cattable.c.id)), 5) - - def test_insert_values_col_expression(self): - with testing.db.connect() as conn: - conn.execute(cattable.insert().values({cattable.c.id: literal(5)})) - eq_(conn.scalar(select(cattable.c.id)), 5) + def test_executemany(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute( + cattable.insert(), + [ + {"id": 89, "description": "Python"}, + {"id": 8, "description": "Ruby"}, + {"id": 3, "description": "Perl"}, + {"id": 1, "description": "Java"}, + ], + ) + cats = conn.execute(cattable.select().order_by(cattable.c.id)) + eq_( + [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")], + list(cats), + ) + conn.execute( + cattable.insert(), + [{"description": "PHP"}, {"description": "Smalltalk"}], + ) + lastcats = conn.execute( + cattable.select().order_by(desc(cattable.c.id)).limit(2) + ) + eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats)) + + def test_insert_plain_param(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute(cattable.insert(), id=5) + eq_(conn.scalar(select(cattable.c.id)), 5) + + def test_insert_values_key_plain(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute(cattable.insert().values(id=5)) + eq_(conn.scalar(select(cattable.c.id)), 5) + + def test_insert_values_key_expression(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute(cattable.insert().values(id=literal(5))) + eq_(conn.scalar(select(cattable.c.id)), 5) + + def test_insert_values_col_plain(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute(cattable.insert().values({cattable.c.id: 5})) + eq_(conn.scalar(select(cattable.c.id)), 5) + + def test_insert_values_col_expression(self, connection): + conn = connection + cattable = self.tables.cattable + conn.execute(cattable.insert().values({cattable.c.id: literal(5)})) + eq_(conn.scalar(select(cattable.c.id)), 5) class QueryUnicodeTest(fixtures.TestBase): @@ -391,37 +383,35 @@ def full_text_search_missing(): """Test if full text search is not implemented and return False if it is and True otherwise.""" - try: - connection = testing.db.connect() - try: - connection.exec_driver_sql( - "CREATE FULLTEXT CATALOG Catalog AS " "DEFAULT" - ) - return False - except Exception: - return True - finally: - connection.close() + if not testing.against("mssql"): + return True + + with testing.db.connect() as conn: + result = conn.exec_driver_sql( + "SELECT cast(SERVERPROPERTY('IsFullTextInstalled') as integer)" + ) + return result.scalar() == 0 -class MatchTest(fixtures.TestBase, AssertsCompiledSQL): +class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): __only_on__ = "mssql" __skip_if__ = (full_text_search_missing,) __backend__ = True + run_setup_tables = "once" + run_inserts = run_deletes = "once" + @classmethod - def setup_class(cls): - global metadata, cattable, matchtable - metadata = MetaData(testing.db) - cattable = Table( + def define_tables(cls, metadata): + Table( "cattable", metadata, Column("id", Integer), Column("description", String(50)), PrimaryKeyConstraint("id", name="PK_cattable"), ) - matchtable = Table( + Table( "matchtable", metadata, Column("id", Integer), @@ -429,24 +419,65 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): Column("category_id", Integer, ForeignKey("cattable.id")), PrimaryKeyConstraint("id", name="PK_matchtable"), ) - DDL( - """CREATE FULLTEXT INDEX + + event.listen( + metadata, + "before_create", + DDL("CREATE FULLTEXT CATALOG Catalog AS DEFAULT"), + ) + event.listen( + metadata, + "after_create", + DDL( + """CREATE FULLTEXT INDEX ON cattable (description) KEY INDEX PK_cattable""" - ).execute_at("after-create", matchtable) - DDL( - """CREATE FULLTEXT INDEX + ), + ) + event.listen( + metadata, + "after_create", + DDL( + """CREATE FULLTEXT INDEX ON matchtable (title) KEY INDEX PK_matchtable""" - ).execute_at("after-create", matchtable) - metadata.create_all() - cattable.insert().execute( + ), + ) + + event.listen( + metadata, + "after_drop", + DDL("DROP FULLTEXT CATALOG Catalog"), + ) + + @classmethod + def setup_bind(cls): + return testing.db.execution_options(isolation_level="AUTOCOMMIT") + + @classmethod + def setup_class(cls): + with testing.db.connect().execution_options( + isolation_level="AUTOCOMMIT" + ) as conn: + try: + conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog") + except: + pass + super(MatchTest, cls).setup_class() + + @classmethod + def insert_data(cls, connection): + cattable, matchtable = cls.tables("cattable", "matchtable") + + connection.execute( + cattable.insert(), [ {"id": 1, "description": "Python"}, {"id": 2, "description": "Ruby"}, - ] + ], ) - matchtable.insert().execute( + connection.execute( + matchtable.insert(), [ { "id": 1, @@ -461,62 +492,53 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): }, {"id": 4, "title": "Guide to Django", "category_id": 1}, {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, - ] + ], ) - DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() - connection = testing.db.connect() - connection.exec_driver_sql("DROP FULLTEXT CATALOG Catalog") - connection.close() + # apparently this is needed! index must run asynchronously + connection.execute(DDL("WAITFOR DELAY '00:00:05'")) def test_expression(self): + matchtable = self.tables.matchtable self.assert_compile( matchtable.c.title.match("somstr"), "CONTAINS (matchtable.title, ?)", ) - def test_simple_match(self): - results = ( + def test_simple_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([2, 5], [r.id for r in results]) - def test_simple_match_with_apostrophe(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match("Matz's")) - .execute() - .fetchall() - ) + def test_simple_match_with_apostrophe(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match("Matz's")) + ).fetchall() eq_([3], [r.id for r in results]) - def test_simple_prefix_match(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match('"nut*"')) - .execute() - .fetchall() - ) + def test_simple_prefix_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match('"nut*"')) + ).fetchall() eq_([5], [r.id for r in results]) - def test_simple_inflectional_match(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')) - .execute() - .fetchall() - ) + def test_simple_inflectional_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where( + matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")') + ) + ).fetchall() eq_([2], [r.id for r in results]) - def test_or_match(self): - results1 = ( + def test_or_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( matchtable.select() .where( or_( @@ -525,31 +547,25 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([3, 5], [r.id for r in results1]) - results2 = ( + results2 = connection.execute( matchtable.select() .where(matchtable.c.title.match("nutshell OR ruby")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([3, 5], [r.id for r in results2]) - def test_and_match(self): - results1 = ( - matchtable.select() - .where( + def test_and_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( + matchtable.select().where( and_( matchtable.c.title.match("python"), matchtable.c.title.match("nutshell"), ) ) - .execute() - .fetchall() - ) + ).fetchall() eq_([5], [r.id for r in results1]) results2 = ( matchtable.select() @@ -559,8 +575,10 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) eq_([5], [r.id for r in results2]) - def test_match_across_joins(self): - results = ( + def test_match_across_joins(self, connection): + matchtable = self.tables.matchtable + cattable = self.tables.cattable + results = connection.execute( matchtable.select() .where( and_( @@ -572,7 +590,5 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 5], [r.id for r in results]) diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index 6009bfb6cb..86c97316ad 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -741,14 +741,9 @@ class IdentityReflectionTest(fixtures.TablesTest): @testing.requires.views def test_reflect_views(self, connection): - try: - with testing.db.connect() as conn: - conn.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1") - insp = inspect(testing.db) - for col in insp.get_columns("view1"): - is_true("dialect_options" not in col) - is_true("identity" in col) - eq_(col["identity"], {}) - finally: - with testing.db.connect() as conn: - conn.exec_driver_sql("DROP VIEW view1") + connection.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1") + insp = inspect(connection) + for col in insp.get_columns("view1"): + is_true("dialect_options" not in col) + is_true("identity" in col) + eq_(col["identity"], {}) diff --git a/test/dialect/mssql/test_types.py b/test/dialect/mssql/test_types.py index 11a2a25b3f..a4a3bedda3 100644 --- a/test/dialect/mssql/test_types.py +++ b/test/dialect/mssql/test_types.py @@ -221,7 +221,7 @@ class RowVersionTest(fixtures.TablesTest): Column("rv", cls(convert_int=convert_int)), ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(t.insert().values(data="foo")) last_ts_1 = conn.exec_driver_sql("SELECT @@DBTS").scalar() @@ -545,7 +545,7 @@ class TypeRoundTripTest( __backend__ = True @testing.provide_metadata - def test_decimal_notation(self): + def test_decimal_notation(self, connection): metadata = self.metadata numeric_table = Table( "numeric_table", @@ -560,7 +560,7 @@ class TypeRoundTripTest( "numericcol", Numeric(precision=38, scale=20, asdecimal=True) ), ) - metadata.create_all() + metadata.create_all(connection) test_items = [ decimal.Decimal(d) for d in ( @@ -623,21 +623,20 @@ class TypeRoundTripTest( ) ] - with testing.db.connect() as conn: - for value in test_items: - result = conn.execute( - numeric_table.insert(), dict(numericcol=value) - ) - primary_key = result.inserted_primary_key - returned = conn.scalar( - select(numeric_table.c.numericcol).where( - numeric_table.c.id == primary_key[0] - ) + for value in test_items: + result = connection.execute( + numeric_table.insert(), dict(numericcol=value) + ) + primary_key = result.inserted_primary_key + returned = connection.scalar( + select(numeric_table.c.numericcol).where( + numeric_table.c.id == primary_key[0] ) - eq_(value, returned) + ) + eq_(value, returned) @testing.provide_metadata - def test_float(self): + def test_float(self, connection): metadata = self.metadata float_table = Table( @@ -652,41 +651,47 @@ class TypeRoundTripTest( Column("floatcol", Float()), ) - metadata.create_all() - try: - test_items = [ - float(d) - for d in ( - "1500000.00000000000000000000", - "-1500000.00000000000000000000", - "1500000", - "0.0000000000000000002", - "0.2", - "-0.0000000000000000002", - "156666.458923543", - "-156666.458923543", - "1", - "-1", - "1234", - "2E-12", - "4E8", - "3E-6", - "3E-7", - "4.1", - "1E-1", - "1E-2", - "1E-3", - "1E-4", - "1E-5", - "1E-6", - "1E-7", - "1E-8", + metadata.create_all(connection) + test_items = [ + float(d) + for d in ( + "1500000.00000000000000000000", + "-1500000.00000000000000000000", + "1500000", + "0.0000000000000000002", + "0.2", + "-0.0000000000000000002", + "156666.458923543", + "-156666.458923543", + "1", + "-1", + "1234", + "2E-12", + "4E8", + "3E-6", + "3E-7", + "4.1", + "1E-1", + "1E-2", + "1E-3", + "1E-4", + "1E-5", + "1E-6", + "1E-7", + "1E-8", + ) + ] + for value in test_items: + result = connection.execute( + float_table.insert(), dict(floatcol=value) + ) + primary_key = result.inserted_primary_key + returned = connection.scalar( + select(float_table.c.floatcol).where( + float_table.c.id == primary_key[0] ) - ] - for value in test_items: - float_table.insert().execute(floatcol=value) - except Exception as e: - raise e + ) + eq_(value, returned) # todo this should suppress warnings, but it does not @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*") @@ -770,18 +775,17 @@ class TypeRoundTripTest( d2 = datetime.datetime(2007, 10, 30, 11, 2, 32) return t, (d1, t1, d2) - def test_date_roundtrips(self, date_fixture): + def test_date_roundtrips(self, date_fixture, connection): t, (d1, t1, d2) = date_fixture - with testing.db.begin() as conn: - conn.execute( - t.insert(), adate=d1, adatetime=d2, atime1=t1, atime2=d2 - ) + connection.execute( + t.insert(), adate=d1, adatetime=d2, atime1=t1, atime2=d2 + ) - row = conn.execute(t.select()).first() - eq_( - (row.adate, row.adatetime, row.atime1, row.atime2), - (d1, d2, t1, d2.time()), - ) + row = connection.execute(t.select()).first() + eq_( + (row.adate, row.adatetime, row.atime1, row.atime2), + (d1, d2, t1, d2.time()), + ) @testing.metadata_fixture() def datetimeoffset_fixture(self, metadata): @@ -870,45 +874,45 @@ class TypeRoundTripTest( dto_param_value, expected_offset_hours, should_fail, + connection, ): t = datetimeoffset_fixture dto_param_value = dto_param_value() - with testing.db.begin() as conn: - if should_fail: - assert_raises( - sa.exc.DBAPIError, - conn.execute, - t.insert(), - adatetimeoffset=dto_param_value, - ) - return - - conn.execute( + if should_fail: + assert_raises( + sa.exc.DBAPIError, + connection.execute, t.insert(), adatetimeoffset=dto_param_value, ) + return - row = conn.execute(t.select()).first() + connection.execute( + t.insert(), + adatetimeoffset=dto_param_value, + ) - if dto_param_value is None: - is_(row.adatetimeoffset, None) - else: - eq_( - row.adatetimeoffset, - datetime.datetime( - 2007, - 10, - 30, - 11, - 2, - 32, - 123456, - util.timezone( - datetime.timedelta(hours=expected_offset_hours) - ), + row = connection.execute(t.select()).first() + + if dto_param_value is None: + is_(row.adatetimeoffset, None) + else: + eq_( + row.adatetimeoffset, + datetime.datetime( + 2007, + 10, + 30, + 11, + 2, + 32, + 123456, + util.timezone( + datetime.timedelta(hours=expected_offset_hours) ), - ) + ), + ) @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*") @testing.provide_metadata @@ -1173,7 +1177,7 @@ class BinaryTest(fixtures.TestBase): if expected is None: expected = data - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(binary_table.insert(), data=data) eq_(conn.scalar(select(binary_table.c.data)), expected) diff --git a/test/dialect/mysql/test_dialect.py b/test/dialect/mysql/test_dialect.py index abd3a491ff..3c569bf058 100644 --- a/test/dialect/mysql/test_dialect.py +++ b/test/dialect/mysql/test_dialect.py @@ -20,7 +20,7 @@ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import mock -from ...engine import test_execute +from ...engine import test_deprecations class BackendDialectTest(fixtures.TestBase): @@ -382,56 +382,56 @@ class RemoveUTCTimestampTest(fixtures.TablesTest): Column("udata", DateTime, onupdate=func.utc_timestamp()), ) - def test_insert_executemany(self): - with testing.db.connect() as conn: - conn.execute( - self.tables.t.insert().values(data=func.utc_timestamp()), - [{"x": 5}, {"x": 6}, {"x": 7}], - ) + def test_insert_executemany(self, connection): + conn = connection + conn.execute( + self.tables.t.insert().values(data=func.utc_timestamp()), + [{"x": 5}, {"x": 6}, {"x": 7}], + ) - def test_update_executemany(self): - with testing.db.connect() as conn: - timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2) - conn.execute( - self.tables.t.insert(), - [ - {"x": 5, "data": timestamp}, - {"x": 6, "data": timestamp}, - {"x": 7, "data": timestamp}, - ], - ) + def test_update_executemany(self, connection): + conn = connection + timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2) + conn.execute( + self.tables.t.insert(), + [ + {"x": 5, "data": timestamp}, + {"x": 6, "data": timestamp}, + {"x": 7, "data": timestamp}, + ], + ) - conn.execute( - self.tables.t.update() - .values(data=func.utc_timestamp()) - .where(self.tables.t.c.x == bindparam("xval")), - [{"xval": 5}, {"xval": 6}, {"xval": 7}], - ) + conn.execute( + self.tables.t.update() + .values(data=func.utc_timestamp()) + .where(self.tables.t.c.x == bindparam("xval")), + [{"xval": 5}, {"xval": 6}, {"xval": 7}], + ) - def test_insert_executemany_w_default(self): - with testing.db.connect() as conn: - conn.execute( - self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}] - ) + def test_insert_executemany_w_default(self, connection): + conn = connection + conn.execute( + self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}] + ) - def test_update_executemany_w_default(self): - with testing.db.connect() as conn: - timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2) - conn.execute( - self.tables.t_default.insert(), - [ - {"x": 5, "idata": timestamp}, - {"x": 6, "idata": timestamp}, - {"x": 7, "idata": timestamp}, - ], - ) + def test_update_executemany_w_default(self, connection): + conn = connection + timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2) + conn.execute( + self.tables.t_default.insert(), + [ + {"x": 5, "idata": timestamp}, + {"x": 6, "idata": timestamp}, + {"x": 7, "idata": timestamp}, + ], + ) - conn.execute( - self.tables.t_default.update() - .values(idata=func.utc_timestamp()) - .where(self.tables.t_default.c.x == bindparam("xval")), - [{"xval": 5}, {"xval": 6}, {"xval": 7}], - ) + conn.execute( + self.tables.t_default.update() + .values(idata=func.utc_timestamp()) + .where(self.tables.t_default.c.x == bindparam("xval")), + [{"xval": 5}, {"xval": 6}, {"xval": 7}], + ) class SQLModeDetectionTest(fixtures.TestBase): @@ -505,7 +505,7 @@ class ExecutionTest(fixtures.TestBase): class AutocommitTextTest( - test_execute.AutocommitKeywordFixture, fixtures.TestBase + test_deprecations.AutocommitKeywordFixture, fixtures.TestBase ): __only_on__ = "mysql", "mariadb" diff --git a/test/dialect/mysql/test_on_duplicate.py b/test/dialect/mysql/test_on_duplicate.py index ed88121a55..dc86aaeb05 100644 --- a/test/dialect/mysql/test_on_duplicate.py +++ b/test/dialect/mysql/test_on_duplicate.py @@ -5,7 +5,6 @@ from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import Table -from sqlalchemy import testing from sqlalchemy.dialects.mysql import insert from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises @@ -47,155 +46,145 @@ class OnDuplicateTest(fixtures.TablesTest): {"id": 2, "bar": "baz"}, ) - def test_on_duplicate_key_update_multirow(self): + def test_on_duplicate_key_update_multirow(self, connection): foos = self.tables.foos - with testing.db.connect() as conn: - conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) - stmt = insert(foos).values( - [dict(id=1, bar="ab"), dict(id=2, bar="b")] - ) - stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) - - result = conn.execute(stmt) - - # multirow, so its ambiguous. this is a behavioral change - # in 1.4 - eq_(result.inserted_primary_key, (None,)) - eq_( - conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), - [(1, "ab", "bz", False)], - ) + conn = connection + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")]) + stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) + + result = conn.execute(stmt) + + # multirow, so its ambiguous. this is a behavioral change + # in 1.4 + eq_(result.inserted_primary_key, (None,)) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "ab", "bz", False)], + ) - def test_on_duplicate_key_update_singlerow(self): + def test_on_duplicate_key_update_singlerow(self, connection): foos = self.tables.foos - with testing.db.connect() as conn: - conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) - stmt = insert(foos).values(dict(id=2, bar="b")) - stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) - - result = conn.execute(stmt) - - # only one row in the INSERT so we do inserted_primary_key - eq_(result.inserted_primary_key, (2,)) - eq_( - conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), - [(1, "b", "bz", False)], - ) + conn = connection + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).values(dict(id=2, bar="b")) + stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar) + + result = conn.execute(stmt) + + # only one row in the INSERT so we do inserted_primary_key + eq_(result.inserted_primary_key, (2,)) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "b", "bz", False)], + ) - def test_on_duplicate_key_update_null_multirow(self): + def test_on_duplicate_key_update_null_multirow(self, connection): foos = self.tables.foos - with testing.db.connect() as conn: - conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) - stmt = insert(foos).values( - [dict(id=1, bar="ab"), dict(id=2, bar="b")] - ) - stmt = stmt.on_duplicate_key_update(updated_once=None) - result = conn.execute(stmt) - - # ambiguous - eq_(result.inserted_primary_key, (None,)) - eq_( - conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), - [(1, "b", "bz", None)], - ) + conn = connection + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")]) + stmt = stmt.on_duplicate_key_update(updated_once=None) + result = conn.execute(stmt) + + # ambiguous + eq_(result.inserted_primary_key, (None,)) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "b", "bz", None)], + ) - def test_on_duplicate_key_update_expression_multirow(self): + def test_on_duplicate_key_update_expression_multirow(self, connection): foos = self.tables.foos - with testing.db.connect() as conn: - conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) - stmt = insert(foos).values( - [dict(id=1, bar="ab"), dict(id=2, bar="b")] - ) - stmt = stmt.on_duplicate_key_update( - bar=func.concat(stmt.inserted.bar, "_foo") - ) - result = conn.execute(stmt) - eq_(result.inserted_primary_key, (None,)) - eq_( - conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), - [(1, "ab_foo", "bz", False)], - ) + conn = connection + conn.execute(insert(foos, dict(id=1, bar="b", baz="bz"))) + stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")]) + stmt = stmt.on_duplicate_key_update( + bar=func.concat(stmt.inserted.bar, "_foo") + ) + result = conn.execute(stmt) + eq_(result.inserted_primary_key, (None,)) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "ab_foo", "bz", False)], + ) - def test_on_duplicate_key_update_preserve_order(self): + def test_on_duplicate_key_update_preserve_order(self, connection): foos = self.tables.foos - with testing.db.connect() as conn: - conn.execute( - insert( - foos, - [ - dict(id=1, bar="b", baz="bz"), - dict(id=2, bar="b", baz="bz2"), - ], - ) - ) - - stmt = insert(foos) - update_condition = foos.c.updated_once == False - - # The following statements show importance of the columns update - # ordering as old values being referenced in UPDATE clause are - # getting replaced one by one from left to right with their new - # values. - stmt1 = stmt.on_duplicate_key_update( + conn = connection + conn.execute( + insert( + foos, [ - ( - "bar", - func.if_( - update_condition, - func.values(foos.c.bar), - foos.c.bar, - ), - ), - ( - "updated_once", - func.if_(update_condition, True, foos.c.updated_once), - ), - ] + dict(id=1, bar="b", baz="bz"), + dict(id=2, bar="b", baz="bz2"), + ], ) - stmt2 = stmt.on_duplicate_key_update( - [ - ( - "updated_once", - func.if_(update_condition, True, foos.c.updated_once), + ) + + stmt = insert(foos) + update_condition = foos.c.updated_once == False + + # The following statements show importance of the columns update + # ordering as old values being referenced in UPDATE clause are + # getting replaced one by one from left to right with their new + # values. + stmt1 = stmt.on_duplicate_key_update( + [ + ( + "bar", + func.if_( + update_condition, + func.values(foos.c.bar), + foos.c.bar, ), - ( - "bar", - func.if_( - update_condition, - func.values(foos.c.bar), - foos.c.bar, - ), + ), + ( + "updated_once", + func.if_(update_condition, True, foos.c.updated_once), + ), + ] + ) + stmt2 = stmt.on_duplicate_key_update( + [ + ( + "updated_once", + func.if_(update_condition, True, foos.c.updated_once), + ), + ( + "bar", + func.if_( + update_condition, + func.values(foos.c.bar), + foos.c.bar, ), - ] - ) - # First statement should succeed updating column bar - conn.execute(stmt1, dict(id=1, bar="ab")) - eq_( - conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), - [(1, "ab", "bz", True)], - ) - # Second statement will do noop update of column bar - conn.execute(stmt2, dict(id=2, bar="ab")) - eq_( - conn.execute(foos.select().where(foos.c.id == 2)).fetchall(), - [(2, "b", "bz2", True)], - ) + ), + ] + ) + # First statement should succeed updating column bar + conn.execute(stmt1, dict(id=1, bar="ab")) + eq_( + conn.execute(foos.select().where(foos.c.id == 1)).fetchall(), + [(1, "ab", "bz", True)], + ) + # Second statement will do noop update of column bar + conn.execute(stmt2, dict(id=2, bar="ab")) + eq_( + conn.execute(foos.select().where(foos.c.id == 2)).fetchall(), + [(2, "b", "bz2", True)], + ) - def test_last_inserted_id(self): + def test_last_inserted_id(self, connection): foos = self.tables.foos - with testing.db.connect() as conn: - stmt = insert(foos).values({"bar": "b", "baz": "bz"}) - result = conn.execute( - stmt.on_duplicate_key_update( - bar=stmt.inserted.bar, baz="newbz" - ) - ) - eq_(result.inserted_primary_key, (1,)) + conn = connection + stmt = insert(foos).values({"bar": "b", "baz": "bz"}) + result = conn.execute( + stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz") + ) + eq_(result.inserted_primary_key, (1,)) - stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"}) - result = conn.execute( - stmt.on_duplicate_key_update( - bar=stmt.inserted.bar, baz="newbz" - ) - ) - eq_(result.inserted_primary_key, (1,)) + stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"}) + result = conn.execute( + stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz") + ) + eq_(result.inserted_primary_key, (1,)) diff --git a/test/dialect/mysql/test_query.py b/test/dialect/mysql/test_query.py index f9d9caf166..f56cd98aa3 100644 --- a/test/dialect/mysql/test_query.py +++ b/test/dialect/mysql/test_query.py @@ -9,7 +9,6 @@ from sqlalchemy import Column from sqlalchemy import false from sqlalchemy import ForeignKey from sqlalchemy import Integer -from sqlalchemy import MetaData from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String @@ -44,16 +43,13 @@ class IdiosyncrasyTest(fixtures.TestBase): ) -class MatchTest(fixtures.TestBase): +class MatchTest(fixtures.TablesTest): __only_on__ = "mysql", "mariadb" __backend__ = True @classmethod - def setup_class(cls): - global metadata, cattable, matchtable - metadata = MetaData(testing.db) - - cattable = Table( + def define_tables(cls, metadata): + Table( "cattable", metadata, Column("id", Integer, primary_key=True), @@ -61,7 +57,7 @@ class MatchTest(fixtures.TestBase): mysql_engine="MyISAM", mariadb_engine="MyISAM", ) - matchtable = Table( + Table( "matchtable", metadata, Column("id", Integer, primary_key=True), @@ -70,15 +66,20 @@ class MatchTest(fixtures.TestBase): mysql_engine="MyISAM", mariadb_engine="MyISAM", ) - metadata.create_all() - cattable.insert().execute( + @classmethod + def insert_data(cls, connection): + cattable, matchtable = cls.tables("cattable", "matchtable") + + connection.execute( + cattable.insert(), [ {"id": 1, "description": "Python"}, {"id": 2, "description": "Ruby"}, - ] + ], ) - matchtable.insert().execute( + connection.execute( + matchtable.insert(), [ { "id": 1, @@ -97,43 +98,36 @@ class MatchTest(fixtures.TestBase): "category_id": 1, }, {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, - ] + ], ) - @classmethod - def teardown_class(cls): - metadata.drop_all() - - def test_simple_match(self): - results = ( + def test_simple_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([2, 5], [r.id for r in results]) - def test_not_match(self): - results = ( + def test_not_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(~matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() ) eq_([1, 3, 4], [r.id for r in results]) - def test_simple_match_with_apostrophe(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match("Matz's")) - .execute() - .fetchall() - ) + def test_simple_match_with_apostrophe(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match("Matz's")) + ).fetchall() eq_([3], [r.id for r in results]) def test_return_value(self, connection): + matchtable = self.tables.matchtable # test [ticket:3263] result = connection.execute( select( @@ -155,8 +149,9 @@ class MatchTest(fixtures.TestBase): ], ) - def test_or_match(self): - results1 = ( + def test_or_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( matchtable.select() .where( or_( @@ -165,42 +160,37 @@ class MatchTest(fixtures.TestBase): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 5], [r.id for r in results1]) - results2 = ( + results2 = connection.execute( matchtable.select() .where(matchtable.c.title.match("nutshell ruby")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 5], [r.id for r in results2]) - def test_and_match(self): - results1 = ( - matchtable.select() - .where( + def test_and_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( + matchtable.select().where( and_( matchtable.c.title.match("python"), matchtable.c.title.match("nutshell"), ) ) - .execute() - .fetchall() - ) + ).fetchall() eq_([5], [r.id for r in results1]) - results2 = ( - matchtable.select() - .where(matchtable.c.title.match("+python +nutshell")) - .execute() - .fetchall() - ) + results2 = connection.execute( + matchtable.select().where( + matchtable.c.title.match("+python +nutshell") + ) + ).fetchall() eq_([5], [r.id for r in results2]) - def test_match_across_joins(self): - results = ( + def test_match_across_joins(self, connection): + matchtable = self.tables.matchtable + cattable = self.tables.cattable + results = connection.execute( matchtable.select() .where( and_( @@ -212,9 +202,7 @@ class MatchTest(fixtures.TestBase): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 5], [r.id for r in results]) diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index 3871dbecca..55d88957a3 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -324,7 +324,8 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): str(reflected.c.c6.server_default.arg).upper(), ) - def test_reflection_with_table_options(self): + @testing.provide_metadata + def test_reflection_with_table_options(self, connection): comment = r"""Comment types type speedily ' " \ '' Fun!""" if testing.against("mariadb"): kwargs = dict( @@ -347,18 +348,15 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): def_table = Table( "mysql_def", - MetaData(), + self.metadata, Column("c1", Integer()), comment=comment, **kwargs ) - with testing.db.connect() as conn: - def_table.create(conn) - try: - reflected = Table("mysql_def", MetaData(), autoload_with=conn) - finally: - def_table.drop(conn) + conn = connection + def_table.create(conn) + reflected = Table("mysql_def", MetaData(), autoload_with=conn) if testing.against("mariadb"): assert def_table.kwargs["mariadb_engine"] == "MEMORY" @@ -554,31 +552,31 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): assert 1 not in list(conn.execute(tbl.select()).first()) @testing.provide_metadata - def test_view_reflection(self): + def test_view_reflection(self, connection): Table( "x", self.metadata, Column("a", Integer), Column("b", String(50)) ) - self.metadata.create_all() + self.metadata.create_all(connection) - with testing.db.connect() as conn: - conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x") - conn.exec_driver_sql( - "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x" - ) - conn.exec_driver_sql( - "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x" - ) - conn.exec_driver_sql( - "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x" - ) + conn = connection + conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x") + conn.exec_driver_sql( + "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x" + ) + conn.exec_driver_sql( + "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x" + ) + conn.exec_driver_sql( + "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x" + ) @event.listens_for(self.metadata, "before_drop") def cleanup(*arg, **kw): - with testing.db.connect() as conn: + with testing.db.begin() as conn: for v in ["v1", "v2", "v3", "v4"]: conn.exec_driver_sql("DROP VIEW %s" % v) - insp = inspect(testing.db) + insp = inspect(connection) for v in ["v1", "v2", "v3", "v4"]: eq_( [ @@ -589,38 +587,36 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ) @testing.provide_metadata - def test_skip_not_describable(self): + def test_skip_not_describable(self, connection): @event.listens_for(self.metadata, "before_drop") def cleanup(*arg, **kw): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql("DROP TABLE IF EXISTS test_t1") conn.exec_driver_sql("DROP TABLE IF EXISTS test_t2") conn.exec_driver_sql("DROP VIEW IF EXISTS test_v") - with testing.db.connect() as conn: - conn.exec_driver_sql("CREATE TABLE test_t1 (id INTEGER)") - conn.exec_driver_sql("CREATE TABLE test_t2 (id INTEGER)") - conn.exec_driver_sql( - "CREATE VIEW test_v AS SELECT id FROM test_t1" - ) - conn.exec_driver_sql("DROP TABLE test_t1") - - m = MetaData() - with expect_warnings( - "Skipping .* Table or view named .?test_v.? could not be " - "reflected: .* references invalid table" - ): - m.reflect(views=True, bind=conn) - eq_(m.tables["test_t2"].name, "test_t2") - - assert_raises_message( - exc.UnreflectableTableError, - "references invalid table", - Table, - "test_v", - MetaData(), - autoload_with=conn, - ) + conn = connection + conn.exec_driver_sql("CREATE TABLE test_t1 (id INTEGER)") + conn.exec_driver_sql("CREATE TABLE test_t2 (id INTEGER)") + conn.exec_driver_sql("CREATE VIEW test_v AS SELECT id FROM test_t1") + conn.exec_driver_sql("DROP TABLE test_t1") + + m = MetaData() + with expect_warnings( + "Skipping .* Table or view named .?test_v.? could not be " + "reflected: .* references invalid table" + ): + m.reflect(views=True, bind=conn) + eq_(m.tables["test_t2"].name, "test_t2") + + assert_raises_message( + exc.UnreflectableTableError, + "references invalid table", + Table, + "test_v", + MetaData(), + autoload_with=conn, + ) @testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support") def test_system_views(self): @@ -663,7 +659,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ): Table("nn_t%d" % idx, meta) # to allow DROP - with testing.db.connect() as c: + with testing.db.begin() as c: c.exec_driver_sql( """ CREATE TABLE nn_t%d ( diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index aafad8dc15..9a2174a24a 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -89,6 +89,8 @@ class DefaultSchemaNameTest(fixtures.TestBase): eng = engines.testing_engine() with eng.connect() as conn: + + trans = conn.begin() eq_( testing.db.dialect._get_default_schema_name(conn), default_schema_name, @@ -104,6 +106,7 @@ class DefaultSchemaNameTest(fixtures.TestBase): ) conn.invalidate() + trans.rollback() eq_( testing.db.dialect._get_default_schema_name(conn), @@ -317,53 +320,51 @@ class ComputedReturningTest(fixtures.TablesTest): implicit_returning=False, ) - def test_computed_insert(self): + def test_computed_insert(self, connection): test = self.tables.test - with testing.db.connect() as conn: - result = conn.execute( - test.insert().return_defaults(), {"id": 1, "foo": 5} - ) + conn = connection + result = conn.execute( + test.insert().return_defaults(), {"id": 1, "foo": 5} + ) - eq_(result.returned_defaults, (47,)) + eq_(result.returned_defaults, (47,)) - eq_(conn.scalar(select(test.c.bar)), 47) + eq_(conn.scalar(select(test.c.bar)), 47) - def test_computed_update_warning(self): + def test_computed_update_warning(self, connection): test = self.tables.test - with testing.db.connect() as conn: - conn.execute(test.insert(), {"id": 1, "foo": 5}) + conn = connection + conn.execute(test.insert(), {"id": 1, "foo": 5}) - if testing.db.dialect._supports_update_returning_computed_cols: + if testing.db.dialect._supports_update_returning_computed_cols: + result = conn.execute( + test.update().values(foo=10).return_defaults() + ) + eq_(result.returned_defaults, (52,)) + else: + with testing.expect_warnings( + "Computed columns don't work with Oracle UPDATE" + ): result = conn.execute( test.update().values(foo=10).return_defaults() ) - eq_(result.returned_defaults, (52,)) - else: - with testing.expect_warnings( - "Computed columns don't work with Oracle UPDATE" - ): - result = conn.execute( - test.update().values(foo=10).return_defaults() - ) - # returns the *old* value - eq_(result.returned_defaults, (47,)) + # returns the *old* value + eq_(result.returned_defaults, (47,)) - eq_(conn.scalar(select(test.c.bar)), 52) + eq_(conn.scalar(select(test.c.bar)), 52) - def test_computed_update_no_warning(self): + def test_computed_update_no_warning(self, connection): test = self.tables.test_no_returning - with testing.db.connect() as conn: - conn.execute(test.insert(), {"id": 1, "foo": 5}) + conn = connection + conn.execute(test.insert(), {"id": 1, "foo": 5}) - result = conn.execute( - test.update().values(foo=10).return_defaults() - ) + result = conn.execute(test.update().values(foo=10).return_defaults()) - # no returning - eq_(result.returned_defaults, None) + # no returning + eq_(result.returned_defaults, None) - eq_(conn.scalar(select(test.c.bar)), 52) + eq_(conn.scalar(select(test.c.bar)), 52) class OutParamTest(fixtures.TestBase, AssertsExecutionResults): @@ -372,7 +373,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): - with testing.db.connect() as c: + with testing.db.begin() as c: c.exec_driver_sql( """ create or replace procedure foo(x_in IN number, x_out OUT number, @@ -404,7 +405,7 @@ end; @classmethod def teardown_class(cls): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(text("DROP PROCEDURE foo")) @@ -674,7 +675,7 @@ class ExecuteTest(fixtures.TestBase): seq.drop(connection) @testing.provide_metadata - def test_limit_offset_for_update(self): + def test_limit_offset_for_update(self, connection): metadata = self.metadata # oracle can't actually do the ROWNUM thing with FOR UPDATE # very well. @@ -685,19 +686,24 @@ class ExecuteTest(fixtures.TestBase): Column("id", Integer, primary_key=True), Column("data", Integer), ) - metadata.create_all() + metadata.create_all(connection) - t.insert().execute( - {"id": 1, "data": 1}, - {"id": 2, "data": 7}, - {"id": 3, "data": 12}, - {"id": 4, "data": 15}, - {"id": 5, "data": 32}, + connection.execute( + t.insert(), + [ + {"id": 1, "data": 1}, + {"id": 2, "data": 7}, + {"id": 3, "data": 12}, + {"id": 4, "data": 15}, + {"id": 5, "data": 32}, + ], ) # here, we can't use ORDER BY. eq_( - t.select().with_for_update().limit(2).execute().fetchall(), + connection.execute( + t.select().with_for_update().limit(2) + ).fetchall(), [(1, 1), (2, 7)], ) @@ -706,7 +712,8 @@ class ExecuteTest(fixtures.TestBase): assert_raises_message( exc.DatabaseError, "ORA-02014", - t.select().with_for_update().limit(2).offset(3).execute, + connection.execute, + t.select().with_for_update().limit(2).offset(3), ) diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index efa21fc1a3..2e515556f3 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -34,11 +34,6 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -def exec_sql(engine, sql, *args, **kwargs): - with engine.connect() as conn: - return conn.exec_driver_sql(sql, *args, **kwargs) - - class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): __only_on__ = "oracle" __backend__ = True @@ -49,62 +44,64 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL): # don't really know how else to go here unless # we connect as the other user. - for stmt in ( - """ -create table %(test_schema)s.parent( - id integer primary key, - data varchar2(50) -); - -COMMENT ON TABLE %(test_schema)s.parent IS 'my table comment'; - -create table %(test_schema)s.child( - id integer primary key, - data varchar2(50), - parent_id integer references %(test_schema)s.parent(id) -); - -create table local_table( - id integer primary key, - data varchar2(50) -); - -create synonym %(test_schema)s.ptable for %(test_schema)s.parent; -create synonym %(test_schema)s.ctable for %(test_schema)s.child; - -create synonym %(test_schema)s_pt for %(test_schema)s.parent; - -create synonym %(test_schema)s.local_table for local_table; - --- can't make a ref from local schema to the --- remote schema's table without this, --- *and* cant give yourself a grant ! --- so we give it to public. ideas welcome. -grant references on %(test_schema)s.parent to public; -grant references on %(test_schema)s.child to public; -""" - % {"test_schema": testing.config.test_schema} - ).split(";"): - if stmt.strip(): - exec_sql(testing.db, stmt) + with testing.db.begin() as conn: + for stmt in ( + """ + create table %(test_schema)s.parent( + id integer primary key, + data varchar2(50) + ); + + COMMENT ON TABLE %(test_schema)s.parent IS 'my table comment'; + + create table %(test_schema)s.child( + id integer primary key, + data varchar2(50), + parent_id integer references %(test_schema)s.parent(id) + ); + + create table local_table( + id integer primary key, + data varchar2(50) + ); + + create synonym %(test_schema)s.ptable for %(test_schema)s.parent; + create synonym %(test_schema)s.ctable for %(test_schema)s.child; + + create synonym %(test_schema)s_pt for %(test_schema)s.parent; + + create synonym %(test_schema)s.local_table for local_table; + + -- can't make a ref from local schema to the + -- remote schema's table without this, + -- *and* cant give yourself a grant ! + -- so we give it to public. ideas welcome. + grant references on %(test_schema)s.parent to public; + grant references on %(test_schema)s.child to public; + """ + % {"test_schema": testing.config.test_schema} + ).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) @classmethod def teardown_class(cls): - for stmt in ( - """ -drop table %(test_schema)s.child; -drop table %(test_schema)s.parent; -drop table local_table; -drop synonym %(test_schema)s.ctable; -drop synonym %(test_schema)s.ptable; -drop synonym %(test_schema)s_pt; -drop synonym %(test_schema)s.local_table; - -""" - % {"test_schema": testing.config.test_schema} - ).split(";"): - if stmt.strip(): - exec_sql(testing.db, stmt) + with testing.db.begin() as conn: + for stmt in ( + """ + drop table %(test_schema)s.child; + drop table %(test_schema)s.parent; + drop table local_table; + drop synonym %(test_schema)s.ctable; + drop synonym %(test_schema)s.ptable; + drop synonym %(test_schema)s_pt; + drop synonym %(test_schema)s.local_table; + + """ + % {"test_schema": testing.config.test_schema} + ).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) @testing.provide_metadata def test_create_same_names_explicit_schema(self): @@ -162,7 +159,7 @@ drop synonym %(test_schema)s.local_table; ) @testing.provide_metadata - def test_create_same_names_implicit_schema(self): + def test_create_same_names_implicit_schema(self, connection): meta = self.metadata parent = Table( "parent", meta, Column("pid", Integer, primary_key=True) @@ -173,10 +170,11 @@ drop synonym %(test_schema)s.local_table; Column("cid", Integer, primary_key=True), Column("pid", Integer, ForeignKey("parent.pid")), ) - meta.create_all() - parent.insert().execute({"pid": 1}) - child.insert().execute({"cid": 1, "pid": 1}) - eq_(child.select().execute().fetchall(), [(1, 1)]) + meta.create_all(connection) + + connection.execute(parent.insert(), {"pid": 1}) + connection.execute(child.insert(), {"cid": 1, "pid": 1}) + eq_(connection.execute(child.select()).fetchall(), [(1, 1)]) def test_reflect_alt_owner_explicit(self): meta = MetaData() @@ -238,9 +236,8 @@ drop synonym %(test_schema)s.local_table; {"text": "my local comment"}, ) - def test_reflect_local_to_remote(self): - exec_sql( - testing.db, + def test_reflect_local_to_remote(self, connection): + connection.exec_driver_sql( "CREATE TABLE localtable (id INTEGER " "PRIMARY KEY, parent_id INTEGER REFERENCES " "%(test_schema)s.parent(id))" @@ -258,7 +255,7 @@ drop synonym %(test_schema)s.local_table; % {"test_schema": testing.config.test_schema}, ) finally: - exec_sql(testing.db, "DROP TABLE localtable") + connection.exec_driver_sql("DROP TABLE localtable") def test_reflect_alt_owner_implicit(self): meta = MetaData() @@ -286,9 +283,8 @@ drop synonym %(test_schema)s.local_table; select(parent, child).select_from(parent.join(child)) ).fetchall() - def test_reflect_alt_owner_synonyms(self): - exec_sql( - testing.db, + def test_reflect_alt_owner_synonyms(self, connection): + connection.exec_driver_sql( "CREATE TABLE localtable (id INTEGER " "PRIMARY KEY, parent_id INTEGER REFERENCES " "%s.ptable(id))" % testing.config.test_schema, @@ -298,7 +294,7 @@ drop synonym %(test_schema)s.local_table; lcl = Table( "localtable", meta, - autoload_with=testing.db, + autoload_with=connection, oracle_resolve_synonyms=True, ) parent = meta.tables["%s.ptable" % testing.config.test_schema] @@ -309,12 +305,11 @@ drop synonym %(test_schema)s.local_table; "localtable.parent_id" % {"test_schema": testing.config.test_schema}, ) - with testing.db.connect() as conn: - conn.execute( - select(parent, lcl).select_from(parent.join(lcl)) - ).fetchall() + connection.execute( + select(parent, lcl).select_from(parent.join(lcl)) + ).fetchall() finally: - exec_sql(testing.db, "DROP TABLE localtable") + connection.exec_driver_sql("DROP TABLE localtable") def test_reflect_remote_synonyms(self): meta = MetaData() @@ -389,19 +384,20 @@ class SystemTableTablenamesTest(fixtures.TestBase): __backend__ = True def setup(self): - exec_sql(testing.db, "create table my_table (id integer)") - exec_sql( - testing.db, - "create global temporary table my_temp_table (id integer)", - ) - exec_sql( - testing.db, "create table foo_table (id integer) tablespace SYSTEM" - ) + with testing.db.begin() as conn: + conn.exec_driver_sql("create table my_table (id integer)") + conn.exec_driver_sql( + "create global temporary table my_temp_table (id integer)", + ) + conn.exec_driver_sql( + "create table foo_table (id integer) tablespace SYSTEM" + ) def teardown(self): - exec_sql(testing.db, "drop table my_temp_table") - exec_sql(testing.db, "drop table my_table") - exec_sql(testing.db, "drop table foo_table") + with testing.db.begin() as conn: + conn.exec_driver_sql("drop table my_temp_table") + conn.exec_driver_sql("drop table my_table") + conn.exec_driver_sql("drop table foo_table") def test_table_names_no_system(self): insp = inspect(testing.db) @@ -430,24 +426,25 @@ class DontReflectIOTTest(fixtures.TestBase): __backend__ = True def setup(self): - exec_sql( - testing.db, - """ - CREATE TABLE admin_docindex( - token char(20), - doc_id NUMBER, - token_frequency NUMBER, - token_offsets VARCHAR2(2000), - CONSTRAINT pk_admin_docindex PRIMARY KEY (token, doc_id)) - ORGANIZATION INDEX - TABLESPACE users - PCTTHRESHOLD 20 - OVERFLOW TABLESPACE users - """, - ) + with testing.db.begin() as conn: + conn.exec_driver_sql( + """ + CREATE TABLE admin_docindex( + token char(20), + doc_id NUMBER, + token_frequency NUMBER, + token_offsets VARCHAR2(2000), + CONSTRAINT pk_admin_docindex PRIMARY KEY (token, doc_id)) + ORGANIZATION INDEX + TABLESPACE users + PCTTHRESHOLD 20 + OVERFLOW TABLESPACE users + """, + ) def teardown(self): - exec_sql(testing.db, "drop table admin_docindex") + with testing.db.begin() as conn: + conn.exec_driver_sql("drop table admin_docindex") def test_reflect_all(self): m = MetaData(testing.db) @@ -456,30 +453,24 @@ class DontReflectIOTTest(fixtures.TestBase): def all_tables_compression_missing(): - try: - exec_sql(testing.db, "SELECT compression FROM all_tables") + with testing.db.connect() as conn: if ( "Enterprise Edition" - not in exec_sql(testing.db, "select * from v$version").scalar() + not in conn.exec_driver_sql("select * from v$version").scalar() # this works in Oracle Database 18c Express Edition Release ) and testing.db.dialect.server_version_info < (18,): return True return False - except Exception: - return True def all_tables_compress_for_missing(): - try: - exec_sql(testing.db, "SELECT compress_for FROM all_tables") + with testing.db.connect() as conn: if ( "Enterprise Edition" - not in exec_sql(testing.db, "select * from v$version").scalar() + not in conn.exec_driver_sql("select * from v$version").scalar() ): return True return False - except Exception: - return True class TableReflectionTest(fixtures.TestBase): @@ -748,7 +739,7 @@ class DBLinkReflectionTest(fixtures.TestBase): # note that the synonym here is still not totally functional # when accessing via a different username as we do with the # multiprocess test suite, so testing here is minimal - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql( "create table test_table " "(id integer primary key, data varchar2(50))" @@ -760,7 +751,7 @@ class DBLinkReflectionTest(fixtures.TestBase): @classmethod def teardown_class(cls): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql("drop synonym test_table_syn") conn.exec_driver_sql("drop table test_table") diff --git a/test/dialect/oracle/test_types.py b/test/dialect/oracle/test_types.py index 8fbf374ee5..db3825d137 100644 --- a/test/dialect/oracle/test_types.py +++ b/test/dialect/oracle/test_types.py @@ -228,16 +228,16 @@ class TypesTest(fixtures.TestBase): @testing.requires.returning @testing.provide_metadata - def test_int_not_float(self): + def test_int_not_float(self, connection): m = self.metadata t1 = Table("t1", m, Column("foo", Integer)) - t1.create() - r = t1.insert().values(foo=5).returning(t1.c.foo).execute() + t1.create(connection) + r = connection.execute(t1.insert().values(foo=5).returning(t1.c.foo)) x = r.scalar() assert x == 5 assert isinstance(x, int) - x = t1.select().scalar() + x = connection.scalar(t1.select()) assert x == 5 assert isinstance(x, int) @@ -281,7 +281,7 @@ class TypesTest(fixtures.TestBase): eq_(conn.execute(s3).fetchall(), [(5, rowid)]) @testing.provide_metadata - def test_interval(self): + def test_interval(self, connection): metadata = self.metadata interval_table = Table( "intervaltable", @@ -291,11 +291,12 @@ class TypesTest(fixtures.TestBase): ), Column("day_interval", oracle.INTERVAL(day_precision=3)), ) - metadata.create_all() - interval_table.insert().execute( - day_interval=datetime.timedelta(days=35, seconds=5743) + metadata.create_all(connection) + connection.execute( + interval_table.insert(), + dict(day_interval=datetime.timedelta(days=35, seconds=5743)), ) - row = interval_table.select().execute().first() + row = connection.execute(interval_table.select()).first() eq_(row["day_interval"], datetime.timedelta(days=35, seconds=5743)) @testing.provide_metadata @@ -364,16 +365,19 @@ class TypesTest(fixtures.TestBase): Column("intcol", Integer), Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)), ) - t1.create() - t1.insert().execute( + t1.create(connection) + connection.execute( + t1.insert(), [ dict(intcol=1, numericcol=float("inf")), dict(intcol=2, numericcol=float("-inf")), - ] + ], ) eq_( - select(t1.c.numericcol).order_by(t1.c.intcol).execute().fetchall(), + connection.execute( + select(t1.c.numericcol).order_by(t1.c.intcol) + ).fetchall(), [(float("inf"),), (float("-inf"),)], ) @@ -393,16 +397,19 @@ class TypesTest(fixtures.TestBase): Column("intcol", Integer), Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True)), ) - t1.create() - t1.insert().execute( + t1.create(connection) + connection.execute( + t1.insert(), [ dict(intcol=1, numericcol=decimal.Decimal("Infinity")), dict(intcol=2, numericcol=decimal.Decimal("-Infinity")), - ] + ], ) eq_( - select(t1.c.numericcol).order_by(t1.c.intcol).execute().fetchall(), + connection.execute( + select(t1.c.numericcol).order_by(t1.c.intcol) + ).fetchall(), [(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)], ) @@ -422,20 +429,21 @@ class TypesTest(fixtures.TestBase): Column("intcol", Integer), Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)), ) - t1.create() - t1.insert().execute( + t1.create(connection) + connection.execute( + t1.insert(), [ dict(intcol=1, numericcol=float("nan")), dict(intcol=2, numericcol=float("-nan")), - ] + ], ) eq_( [ tuple(str(col) for col in row) - for row in select(t1.c.numericcol) - .order_by(t1.c.intcol) - .execute() + for row in connection.execute( + select(t1.c.numericcol).order_by(t1.c.intcol) + ) ], [("nan",), ("nan",)], ) @@ -786,7 +794,7 @@ class TypesTest(fixtures.TestBase): eq_(connection.execute(raw_table.select()).first(), (1, b("ABCDEF"))) @testing.provide_metadata - def test_reflect_nvarchar(self): + def test_reflect_nvarchar(self, connection): metadata = self.metadata Table( "tnv", @@ -794,31 +802,30 @@ class TypesTest(fixtures.TestBase): Column("nv_data", sqltypes.NVARCHAR(255)), Column("c_data", sqltypes.NCHAR(20)), ) - metadata.create_all() + metadata.create_all(connection) m2 = MetaData() - t2 = Table("tnv", m2, autoload_with=testing.db) + t2 = Table("tnv", m2, autoload_with=connection) assert isinstance(t2.c.nv_data.type, sqltypes.NVARCHAR) assert isinstance(t2.c.c_data.type, sqltypes.NCHAR) if testing.against("oracle+cx_oracle"): assert isinstance( - t2.c.nv_data.type.dialect_impl(testing.db.dialect), + t2.c.nv_data.type.dialect_impl(connection.dialect), cx_oracle._OracleUnicodeStringNCHAR, ) assert isinstance( - t2.c.c_data.type.dialect_impl(testing.db.dialect), + t2.c.c_data.type.dialect_impl(connection.dialect), cx_oracle._OracleNChar, ) data = u("m’a réveillé.") - with testing.db.connect() as conn: - conn.execute(t2.insert(), dict(nv_data=data, c_data=data)) - nv_data, c_data = conn.execute(t2.select()).first() - eq_(nv_data, data) - eq_(c_data, data + (" " * 7)) # char is space padded - assert isinstance(nv_data, util.text_type) - assert isinstance(c_data, util.text_type) + connection.execute(t2.insert(), dict(nv_data=data, c_data=data)) + nv_data, c_data = connection.execute(t2.select()).first() + eq_(nv_data, data) + eq_(c_data, data + (" " * 7)) # char is space padded + assert isinstance(nv_data, util.text_type) + assert isinstance(c_data, util.text_type) @testing.provide_metadata def test_reflect_unicode_no_nvarchar(self): @@ -1183,7 +1190,7 @@ class SetInputSizesTest(fixtures.TestBase): else: engine = testing.db - with engine.connect() as conn: + with engine.begin() as conn: connection_fairy = conn.connection for tab in [t1, t2, t3]: with mock.patch.object( diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 5cea604d68..3bd8e9da0b 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -36,6 +36,7 @@ from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_VALUES from sqlalchemy.engine import cursor as _cursor from sqlalchemy.engine import engine_from_config from sqlalchemy.engine import url +from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -51,7 +52,7 @@ from sqlalchemy.testing.assertions import eq_regex from sqlalchemy.testing.assertions import ne_ from sqlalchemy.util import u from sqlalchemy.util import ue -from ...engine import test_execute +from ...engine import test_deprecations if True: from sqlalchemy.dialects.postgresql.psycopg2 import ( @@ -195,6 +196,20 @@ class ExecuteManyMode(object): options = None + @config.fixture() + def connection(self): + eng = engines.testing_engine(options=self.options) + + conn = eng.connect() + trans = conn.begin() + try: + yield conn + finally: + if trans.is_active: + trans.rollback() + conn.close() + eng.dispose() + @classmethod def define_tables(cls, metadata): Table( @@ -213,20 +228,12 @@ class ExecuteManyMode(object): Column(ue("\u6e2c\u8a66"), Integer), ) - def setup(self): - super(ExecuteManyMode, self).setup() - self.engine = engines.testing_engine(options=self.options) - - def teardown(self): - self.engine.dispose() - super(ExecuteManyMode, self).teardown() - - def test_insert(self): + def test_insert(self, connection): from psycopg2 import extras - values_page_size = self.engine.dialect.executemany_values_page_size - batch_page_size = self.engine.dialect.executemany_batch_page_size - if self.engine.dialect.executemany_mode & EXECUTEMANY_VALUES: + values_page_size = connection.dialect.executemany_values_page_size + batch_page_size = connection.dialect.executemany_batch_page_size + if connection.dialect.executemany_mode & EXECUTEMANY_VALUES: meth = extras.execute_values stmt = "INSERT INTO data (x, y) VALUES %s" expected_kwargs = { @@ -234,7 +241,7 @@ class ExecuteManyMode(object): "page_size": values_page_size, "fetch": False, } - elif self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH: + elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH: meth = extras.execute_batch stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)" expected_kwargs = {"page_size": batch_page_size} @@ -244,24 +251,23 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with self.engine.connect() as conn: - conn.execute( - self.tables.data.insert(), - [ - {"x": "x1", "y": "y1"}, - {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"}, - ], - ) + connection.execute( + self.tables.data.insert(), + [ + {"x": "x1", "y": "y1"}, + {"x": "x2", "y": "y2"}, + {"x": "x3", "y": "y3"}, + ], + ) - eq_( - conn.execute(select(self.tables.data)).fetchall(), - [ - (1, "x1", "y1", 5), - (2, "x2", "y2", 5), - (3, "x3", "y3", 5), - ], - ) + eq_( + connection.execute(select(self.tables.data)).fetchall(), + [ + (1, "x1", "y1", 5), + (2, "x2", "y2", 5), + (3, "x3", "y3", 5), + ], + ) eq_( mock_exec.mock_calls, [ @@ -278,14 +284,13 @@ class ExecuteManyMode(object): ], ) - def test_insert_no_page_size(self): + def test_insert_no_page_size(self, connection): from psycopg2 import extras - values_page_size = self.engine.dialect.executemany_values_page_size - batch_page_size = self.engine.dialect.executemany_batch_page_size + values_page_size = connection.dialect.executemany_values_page_size + batch_page_size = connection.dialect.executemany_batch_page_size - eng = self.engine - if eng.dialect.executemany_mode & EXECUTEMANY_VALUES: + if connection.dialect.executemany_mode & EXECUTEMANY_VALUES: meth = extras.execute_values stmt = "INSERT INTO data (x, y) VALUES %s" expected_kwargs = { @@ -293,7 +298,7 @@ class ExecuteManyMode(object): "page_size": values_page_size, "fetch": False, } - elif eng.dialect.executemany_mode & EXECUTEMANY_BATCH: + elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH: meth = extras.execute_batch stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)" expected_kwargs = {"page_size": batch_page_size} @@ -303,15 +308,14 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with eng.connect() as conn: - conn.execute( - self.tables.data.insert(), - [ - {"x": "x1", "y": "y1"}, - {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"}, - ], - ) + connection.execute( + self.tables.data.insert(), + [ + {"x": "x1", "y": "y1"}, + {"x": "x2", "y": "y2"}, + {"x": "x3", "y": "y3"}, + ], + ) eq_( mock_exec.mock_calls, @@ -356,7 +360,7 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with eng.connect() as conn: + with eng.begin() as conn: conn.execute( self.tables.data.insert(), [ @@ -398,11 +402,10 @@ class ExecuteManyMode(object): eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)]) - def test_update_fallback(self): + def test_update_fallback(self, connection): from psycopg2 import extras - batch_page_size = self.engine.dialect.executemany_batch_page_size - eng = self.engine + batch_page_size = connection.dialect.executemany_batch_page_size meth = extras.execute_batch stmt = "UPDATE data SET y=%(yval)s WHERE data.x = %(xval)s" expected_kwargs = {"page_size": batch_page_size} @@ -410,18 +413,17 @@ class ExecuteManyMode(object): with mock.patch.object( extras, meth.__name__, side_effect=meth ) as mock_exec: - with eng.connect() as conn: - conn.execute( - self.tables.data.update() - .where(self.tables.data.c.x == bindparam("xval")) - .values(y=bindparam("yval")), - [ - {"xval": "x1", "yval": "y5"}, - {"xval": "x3", "yval": "y6"}, - ], - ) + connection.execute( + self.tables.data.update() + .where(self.tables.data.c.x == bindparam("xval")) + .values(y=bindparam("yval")), + [ + {"xval": "x1", "yval": "y5"}, + {"xval": "x3", "yval": "y6"}, + ], + ) - if eng.dialect.executemany_mode & EXECUTEMANY_BATCH: + if connection.dialect.executemany_mode & EXECUTEMANY_BATCH: eq_( mock_exec.mock_calls, [ @@ -439,36 +441,34 @@ class ExecuteManyMode(object): else: eq_(mock_exec.mock_calls, []) - def test_not_sane_rowcount(self): - self.engine.connect().close() - if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH: - assert not self.engine.dialect.supports_sane_multi_rowcount + def test_not_sane_rowcount(self, connection): + if connection.dialect.executemany_mode & EXECUTEMANY_BATCH: + assert not connection.dialect.supports_sane_multi_rowcount else: - assert self.engine.dialect.supports_sane_multi_rowcount + assert connection.dialect.supports_sane_multi_rowcount - def test_update(self): - with self.engine.connect() as conn: - conn.execute( - self.tables.data.insert(), - [ - {"x": "x1", "y": "y1"}, - {"x": "x2", "y": "y2"}, - {"x": "x3", "y": "y3"}, - ], - ) + def test_update(self, connection): + connection.execute( + self.tables.data.insert(), + [ + {"x": "x1", "y": "y1"}, + {"x": "x2", "y": "y2"}, + {"x": "x3", "y": "y3"}, + ], + ) - conn.execute( - self.tables.data.update() - .where(self.tables.data.c.x == bindparam("xval")) - .values(y=bindparam("yval")), - [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}], - ) - eq_( - conn.execute( - select(self.tables.data).order_by(self.tables.data.c.id) - ).fetchall(), - [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)], - ) + connection.execute( + self.tables.data.update() + .where(self.tables.data.c.x == bindparam("xval")) + .values(y=bindparam("yval")), + [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}], + ) + eq_( + connection.execute( + select(self.tables.data).order_by(self.tables.data.c.id) + ).fetchall(), + [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)], + ) class ExecutemanyBatchModeTest(ExecuteManyMode, fixtures.TablesTest): @@ -578,7 +578,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): [(pk,) for pk in range(1 + first_pk, total_rows + first_pk)], ) - def test_insert_w_newlines(self): + def test_insert_w_newlines(self, connection): from psycopg2 import extras t = self.tables.data @@ -606,15 +606,14 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): extras, "execute_values", side_effect=meth ) as mock_exec: - with self.engine.connect() as conn: - conn.execute( - ins, - [ - {"id": 1, "y": "y1", "z": 1}, - {"id": 2, "y": "y2", "z": 2}, - {"id": 3, "y": "y3", "z": 3}, - ], - ) + connection.execute( + ins, + [ + {"id": 1, "y": "y1", "z": 1}, + {"id": 2, "y": "y2", "z": 2}, + {"id": 3, "y": "y3", "z": 3}, + ], + ) eq_( mock_exec.mock_calls, @@ -629,12 +628,12 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): ), template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)", fetch=False, - page_size=conn.dialect.executemany_values_page_size, + page_size=connection.dialect.executemany_values_page_size, ) ], ) - def test_insert_modified_by_event(self): + def test_insert_modified_by_event(self, connection): from psycopg2 import extras t = self.tables.data @@ -664,33 +663,33 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest): extras, "execute_batch", side_effect=meth ) as mock_batch: - with self.engine.connect() as conn: - - # create an event hook that will change the statement to - # something else, meaning the dialect has to detect that - # insert_single_values_expr is no longer useful - @event.listens_for(conn, "before_cursor_execute", retval=True) - def before_cursor_execute( - conn, cursor, statement, parameters, context, executemany - ): - statement = ( - "INSERT INTO data (id, y, z) VALUES " - "(%(id)s, %(y)s, %(z)s)" - ) - return statement, parameters - - conn.execute( - ins, - [ - {"id": 1, "y": "y1", "z": 1}, - {"id": 2, "y": "y2", "z": 2}, - {"id": 3, "y": "y3", "z": 3}, - ], + # create an event hook that will change the statement to + # something else, meaning the dialect has to detect that + # insert_single_values_expr is no longer useful + @event.listens_for( + connection, "before_cursor_execute", retval=True + ) + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + statement = ( + "INSERT INTO data (id, y, z) VALUES " + "(%(id)s, %(y)s, %(z)s)" ) + return statement, parameters + + connection.execute( + ins, + [ + {"id": 1, "y": "y1", "z": 1}, + {"id": 2, "y": "y2", "z": 2}, + {"id": 3, "y": "y3", "z": 3}, + ], + ) eq_(mock_values.mock_calls, []) - if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH: + if connection.dialect.executemany_mode & EXECUTEMANY_BATCH: eq_( mock_batch.mock_calls, [ @@ -727,10 +726,10 @@ class ExecutemanyFlagOptionsTest(fixtures.TablesTest): ("values_only", EXECUTEMANY_VALUES), ("values_plus_batch", EXECUTEMANY_VALUES_PLUS_BATCH), ]: - self.engine = engines.testing_engine( + connection = engines.testing_engine( options={"executemany_mode": opt} ) - is_(self.engine.dialect.executemany_mode, expected) + is_(connection.dialect.executemany_mode, expected) def test_executemany_wrong_flag_options(self): for opt in [1, True, "batch_insert"]: @@ -1082,7 +1081,7 @@ $$ LANGUAGE plpgsql; t.create(connection, checkfirst=True) @testing.provide_metadata - def test_schema_roundtrips(self): + def test_schema_roundtrips(self, connection): meta = self.metadata users = Table( "users", @@ -1091,33 +1090,37 @@ $$ LANGUAGE plpgsql; Column("name", String(50)), schema="test_schema", ) - users.create() - users.insert().execute(id=1, name="name1") - users.insert().execute(id=2, name="name2") - users.insert().execute(id=3, name="name3") - users.insert().execute(id=4, name="name4") + users.create(connection) + connection.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=2, name="name2")) + connection.execute(users.insert(), dict(id=3, name="name3")) + connection.execute(users.insert(), dict(id=4, name="name4")) eq_( - users.select().where(users.c.name == "name2").execute().fetchall(), + connection.execute( + users.select().where(users.c.name == "name2") + ).fetchall(), [(2, "name2")], ) eq_( - users.select(use_labels=True) - .where(users.c.name == "name2") - .execute() - .fetchall(), + connection.execute( + users.select().apply_labels().where(users.c.name == "name2") + ).fetchall(), [(2, "name2")], ) - users.delete().where(users.c.id == 3).execute() + connection.execute(users.delete().where(users.c.id == 3)) eq_( - users.select().where(users.c.name == "name3").execute().fetchall(), + connection.execute( + users.select().where(users.c.name == "name3") + ).fetchall(), [], ) - users.update().where(users.c.name == "name4").execute(name="newname") + connection.execute( + users.update().where(users.c.name == "name4"), dict(name="newname") + ) eq_( - users.select(use_labels=True) - .where(users.c.id == 4) - .execute() - .fetchall(), + connection.execute( + users.select().apply_labels().where(users.c.id == 4) + ).fetchall(), [(4, "newname")], ) @@ -1233,7 +1236,7 @@ $$ LANGUAGE plpgsql; ne_(conn.connection.status, STATUS_IN_TRANSACTION) -class AutocommitTextTest(test_execute.AutocommitTextTest): +class AutocommitTextTest(test_deprecations.AutocommitTextTest): __only_on__ = "postgresql" def test_grant(self): diff --git a/test/dialect/postgresql/test_on_conflict.py b/test/dialect/postgresql/test_on_conflict.py index 7604878426..4e96cc6a21 100644 --- a/test/dialect/postgresql/test_on_conflict.py +++ b/test/dialect/postgresql/test_on_conflict.py @@ -99,28 +99,29 @@ class OnConflictTest(fixtures.TablesTest): ValueError, insert(self.tables.users).on_conflict_do_update ) - def test_on_conflict_do_nothing(self): + def test_on_conflict_do_nothing(self, connection): users = self.tables.users - with testing.db.connect() as conn: - result = conn.execute( - insert(users).on_conflict_do_nothing(), - dict(id=1, name="name1"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - result = conn.execute( - insert(users).on_conflict_do_nothing(), - dict(id=1, name="name2"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + result = connection.execute( + insert(users).on_conflict_do_nothing(), + dict(id=1, name="name1"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + result = connection.execute( + insert(users).on_conflict_do_nothing(), + dict(id=1, name="name2"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) def test_on_conflict_do_nothing_connectionless(self, connection): users = self.tables.users_xtra @@ -147,95 +148,99 @@ class OnConflictTest(fixtures.TablesTest): ) @testing.provide_metadata - def test_on_conflict_do_nothing_target(self): + def test_on_conflict_do_nothing_target(self, connection): users = self.tables.users - with testing.db.connect() as conn: - result = conn.execute( - insert(users).on_conflict_do_nothing( - index_elements=users.primary_key.columns - ), - dict(id=1, name="name1"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - result = conn.execute( - insert(users).on_conflict_do_nothing( - index_elements=users.primary_key.columns - ), - dict(id=1, name="name2"), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) - - def test_on_conflict_do_update_one(self): + result = connection.execute( + insert(users).on_conflict_do_nothing( + index_elements=users.primary_key.columns + ), + dict(id=1, name="name1"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + result = connection.execute( + insert(users).on_conflict_do_nothing( + index_elements=users.primary_key.columns + ), + dict(id=1, name="name2"), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) + + def test_on_conflict_do_update_one(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], set_=dict(name=i.excluded.name) - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], set_=dict(name=i.excluded.name) + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_schema(self): + def test_on_conflict_do_update_schema(self, connection): users = self.tables.get("%s.users_schema" % config.test_schema) - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], set_=dict(name=i.excluded.name) - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], set_=dict(name=i.excluded.name) + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_column_as_key_set(self): + def test_on_conflict_do_update_column_as_key_set(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_={users.c.name: i.excluded.name}, - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: i.excluded.name}, + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_clauseelem_as_key_set(self): + def test_on_conflict_do_update_clauseelem_as_key_set(self, connection): users = self.tables.users class MyElem(object): @@ -245,162 +250,165 @@ class OnConflictTest(fixtures.TablesTest): def __clause_element__(self): return self.expr - with testing.db.connect() as conn: - conn.execute( - users.insert(), - {"id": 1, "name": "name1"}, - ) + connection.execute( + users.insert(), + {"id": 1, "name": "name1"}, + ) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_={MyElem(users.c.name): i.excluded.name}, - ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"}) - result = conn.execute(i) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={MyElem(users.c.name): i.excluded.name}, + ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"}) + result = connection.execute(i) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_column_as_key_set_schema(self): + def test_on_conflict_do_update_column_as_key_set_schema(self, connection): users = self.tables.get("%s.users_schema" % config.test_schema) - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_={users.c.name: i.excluded.name}, - ) - result = conn.execute(i, dict(id=1, name="name1")) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_={users.c.name: i.excluded.name}, + ) + result = connection.execute(i, dict(id=1, name="name1")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name1")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name1")], + ) - def test_on_conflict_do_update_two(self): + def test_on_conflict_do_update_two(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.id], - set_=dict(id=i.excluded.id, name=i.excluded.name), - ) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.id], + set_=dict(id=i.excluded.id, name=i.excluded.name), + ) - result = conn.execute(i, dict(id=1, name="name2")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + result = connection.execute(i, dict(id=1, name="name2")) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name2")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name2")], + ) - def test_on_conflict_do_update_three(self): + def test_on_conflict_do_update_three(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(name=i.excluded.name), - ) - result = conn.execute(i, dict(id=1, name="name3")) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(name=i.excluded.name), + ) + result = connection.execute(i, dict(id=1, name="name3")) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name3")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name3")], + ) - def test_on_conflict_do_update_four(self): + def test_on_conflict_do_update_four(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(id=i.excluded.id, name=i.excluded.name), - ).values(id=1, name="name4") + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(id=i.excluded.id, name=i.excluded.name), + ).values(id=1, name="name4") - result = conn.execute(i) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + result = connection.execute(i) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name4")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name4")], + ) - def test_on_conflict_do_update_five(self): + def test_on_conflict_do_update_five(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=1, name="name1")) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(id=10, name="I'm a name"), - ).values(id=1, name="name4") + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(id=10, name="I'm a name"), + ).values(id=1, name="name4") - result = conn.execute(i) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) + result = connection.execute(i) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) - eq_( - conn.execute( - users.select().where(users.c.id == 10) - ).fetchall(), - [(10, "I'm a name")], - ) + eq_( + connection.execute( + users.select().where(users.c.id == 10) + ).fetchall(), + [(10, "I'm a name")], + ) - def test_on_conflict_do_update_multivalues(self): + def test_on_conflict_do_update_multivalues(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(id=1, name="name1")) - conn.execute(users.insert(), dict(id=2, name="name2")) - - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict(name="updated"), - where=(i.excluded.name != "name12"), - ).values( - [ - dict(id=1, name="name11"), - dict(id=2, name="name12"), - dict(id=3, name="name13"), - dict(id=4, name="name14"), - ] - ) - - result = conn.execute(i) - eq_(result.inserted_primary_key, (None,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().order_by(users.c.id)).fetchall(), - [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")], - ) + connection.execute(users.insert(), dict(id=1, name="name1")) + connection.execute(users.insert(), dict(id=2, name="name2")) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict(name="updated"), + where=(i.excluded.name != "name12"), + ).values( + [ + dict(id=1, name="name11"), + dict(id=2, name="name12"), + dict(id=3, name="name13"), + dict(id=4, name="name14"), + ] + ) + + result = connection.execute(i) + eq_(result.inserted_primary_key, (None,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute(users.select().order_by(users.c.id)).fetchall(), + [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")], + ) def _exotic_targets_fixture(self, conn): users = self.tables.users_xtra @@ -429,260 +437,250 @@ class OnConflictTest(fixtures.TablesTest): [(1, "name1", "name1@gmail.com", "not")], ) - def test_on_conflict_do_update_exotic_targets_two(self): + def test_on_conflict_do_update_exotic_targets_two(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try primary key constraint: cause an upsert on unique id column - i = insert(users) - i = i.on_conflict_do_update( - index_elements=users.primary_key.columns, - set_=dict( - name=i.excluded.name, login_email=i.excluded.login_email - ), - ) - result = conn.execute( - i, - dict( - id=1, - name="name2", - login_email="name1@gmail.com", - lets_index_this="not", - ), - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute(users.select().where(users.c.id == 1)).fetchall(), - [(1, "name2", "name1@gmail.com", "not")], - ) - - def test_on_conflict_do_update_exotic_targets_three(self): + self._exotic_targets_fixture(connection) + # try primary key constraint: cause an upsert on unique id column + i = insert(users) + i = i.on_conflict_do_update( + index_elements=users.primary_key.columns, + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + result = connection.execute( + i, + dict( + id=1, + name="name2", + login_email="name1@gmail.com", + lets_index_this="not", + ), + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.id == 1) + ).fetchall(), + [(1, "name2", "name1@gmail.com", "not")], + ) + + def test_on_conflict_do_update_exotic_targets_three(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try unique constraint: cause an upsert on target - # login_email, not id - i = insert(users) - i = i.on_conflict_do_update( - constraint=self.unique_constraint, - set_=dict( - id=i.excluded.id, - name=i.excluded.name, - login_email=i.excluded.login_email, - ), - ) - # note: lets_index_this value totally ignored in SET clause. - result = conn.execute( - i, - dict( - id=42, - name="nameunique", - login_email="name2@gmail.com", - lets_index_this="unique", - ), - ) - eq_(result.inserted_primary_key, (42,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute( - users.select().where( - users.c.login_email == "name2@gmail.com" - ) - ).fetchall(), - [(42, "nameunique", "name2@gmail.com", "not")], - ) - - def test_on_conflict_do_update_exotic_targets_four(self): + self._exotic_targets_fixture(connection) + # try unique constraint: cause an upsert on target + # login_email, not id + i = insert(users) + i = i.on_conflict_do_update( + constraint=self.unique_constraint, + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), + ) + # note: lets_index_this value totally ignored in SET clause. + result = connection.execute( + i, + dict( + id=42, + name="nameunique", + login_email="name2@gmail.com", + lets_index_this="unique", + ), + ) + eq_(result.inserted_primary_key, (42,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.login_email == "name2@gmail.com") + ).fetchall(), + [(42, "nameunique", "name2@gmail.com", "not")], + ) + + def test_on_conflict_do_update_exotic_targets_four(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try unique constraint by name: cause an - # upsert on target login_email, not id - i = insert(users) - i = i.on_conflict_do_update( - constraint=self.unique_constraint.name, - set_=dict( - id=i.excluded.id, - name=i.excluded.name, - login_email=i.excluded.login_email, - ), - ) - # note: lets_index_this value totally ignored in SET clause. - - result = conn.execute( - i, - dict( - id=43, - name="nameunique2", - login_email="name2@gmail.com", - lets_index_this="unique", - ), - ) - eq_(result.inserted_primary_key, (43,)) - eq_(result.returned_defaults, None) - - eq_( - conn.execute( - users.select().where( - users.c.login_email == "name2@gmail.com" - ) - ).fetchall(), - [(43, "nameunique2", "name2@gmail.com", "not")], - ) - - def test_on_conflict_do_update_exotic_targets_four_no_pk(self): + self._exotic_targets_fixture(connection) + # try unique constraint by name: cause an + # upsert on target login_email, not id + i = insert(users) + i = i.on_conflict_do_update( + constraint=self.unique_constraint.name, + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), + ) + # note: lets_index_this value totally ignored in SET clause. + + result = connection.execute( + i, + dict( + id=43, + name="nameunique2", + login_email="name2@gmail.com", + lets_index_this="unique", + ), + ) + eq_(result.inserted_primary_key, (43,)) + eq_(result.returned_defaults, None) + + eq_( + connection.execute( + users.select().where(users.c.login_email == "name2@gmail.com") + ).fetchall(), + [(43, "nameunique2", "name2@gmail.com", "not")], + ) + + def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try unique constraint by name: cause an - # upsert on target login_email, not id - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.login_email], - set_=dict( - id=i.excluded.id, - name=i.excluded.name, - login_email=i.excluded.login_email, - ), - ) - - result = conn.execute( - i, dict(name="name3", login_email="name1@gmail.com") - ) - eq_(result.inserted_primary_key, (1,)) - eq_(result.returned_defaults, (1,)) - - eq_( - conn.execute(users.select().order_by(users.c.id)).fetchall(), - [ - (1, "name3", "name1@gmail.com", "not"), - (2, "name2", "name2@gmail.com", "not"), - ], - ) - - def test_on_conflict_do_update_exotic_targets_five(self): + self._exotic_targets_fixture(connection) + # try unique constraint by name: cause an + # upsert on target login_email, not id + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.login_email], + set_=dict( + id=i.excluded.id, + name=i.excluded.name, + login_email=i.excluded.login_email, + ), + ) + + result = connection.execute( + i, dict(name="name3", login_email="name1@gmail.com") + ) + eq_(result.inserted_primary_key, (1,)) + eq_(result.returned_defaults, (1,)) + + eq_( + connection.execute(users.select().order_by(users.c.id)).fetchall(), + [ + (1, "name3", "name1@gmail.com", "not"), + (2, "name2", "name2@gmail.com", "not"), + ], + ) + + def test_on_conflict_do_update_exotic_targets_five(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - # try bogus index - i = insert(users) - i = i.on_conflict_do_update( - index_elements=self.bogus_index.columns, - index_where=self.bogus_index.dialect_options["postgresql"][ - "where" - ], - set_=dict( - name=i.excluded.name, login_email=i.excluded.login_email - ), - ) - - assert_raises( - exc.ProgrammingError, - conn.execute, - i, - dict( - id=1, - name="namebogus", - login_email="bogus@gmail.com", - lets_index_this="bogus", - ), - ) - - def test_on_conflict_do_update_exotic_targets_six(self): + self._exotic_targets_fixture(connection) + # try bogus index + i = insert(users) + i = i.on_conflict_do_update( + index_elements=self.bogus_index.columns, + index_where=self.bogus_index.dialect_options["postgresql"][ + "where" + ], + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + + assert_raises( + exc.ProgrammingError, + connection.execute, + i, + dict( + id=1, + name="namebogus", + login_email="bogus@gmail.com", + lets_index_this="bogus", + ), + ) + + def test_on_conflict_do_update_exotic_targets_six(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - conn.execute( - insert(users), + connection.execute( + insert(users), + dict( + id=1, + name="name1", + login_email="mail1@gmail.com", + lets_index_this="unique_name", + ), + ) + + i = insert(users) + i = i.on_conflict_do_update( + index_elements=self.unique_partial_index.columns, + index_where=self.unique_partial_index.dialect_options[ + "postgresql" + ]["where"], + set_=dict( + name=i.excluded.name, login_email=i.excluded.login_email + ), + ) + + connection.execute( + i, + [ dict( - id=1, name="name1", - login_email="mail1@gmail.com", + login_email="mail2@gmail.com", lets_index_this="unique_name", - ), - ) - - i = insert(users) - i = i.on_conflict_do_update( - index_elements=self.unique_partial_index.columns, - index_where=self.unique_partial_index.dialect_options[ - "postgresql" - ]["where"], - set_=dict( - name=i.excluded.name, login_email=i.excluded.login_email - ), - ) - - conn.execute( - i, - [ - dict( - name="name1", - login_email="mail2@gmail.com", - lets_index_this="unique_name", - ) - ], - ) - - eq_( - conn.execute(users.select()).fetchall(), - [(1, "name1", "mail2@gmail.com", "unique_name")], - ) - - def test_on_conflict_do_update_no_row_actually_affected(self): + ) + ], + ) + + eq_( + connection.execute(users.select()).fetchall(), + [(1, "name1", "mail2@gmail.com", "unique_name")], + ) + + def test_on_conflict_do_update_no_row_actually_affected(self, connection): users = self.tables.users_xtra - with testing.db.connect() as conn: - self._exotic_targets_fixture(conn) - i = insert(users) - i = i.on_conflict_do_update( - index_elements=[users.c.login_email], - set_=dict(name="new_name"), - where=(i.excluded.name == "other_name"), - ) - result = conn.execute( - i, dict(name="name2", login_email="name1@gmail.com") - ) - - eq_(result.returned_defaults, None) - eq_(result.inserted_primary_key, None) - - eq_( - conn.execute(users.select()).fetchall(), - [ - (1, "name1", "name1@gmail.com", "not"), - (2, "name2", "name2@gmail.com", "not"), - ], - ) - - def test_on_conflict_do_update_special_types_in_set(self): + self._exotic_targets_fixture(connection) + i = insert(users) + i = i.on_conflict_do_update( + index_elements=[users.c.login_email], + set_=dict(name="new_name"), + where=(i.excluded.name == "other_name"), + ) + result = connection.execute( + i, dict(name="name2", login_email="name1@gmail.com") + ) + + eq_(result.returned_defaults, None) + eq_(result.inserted_primary_key, None) + + eq_( + connection.execute(users.select()).fetchall(), + [ + (1, "name1", "name1@gmail.com", "not"), + (2, "name2", "name2@gmail.com", "not"), + ], + ) + + def test_on_conflict_do_update_special_types_in_set(self, connection): bind_targets = self.tables.bind_targets - with testing.db.connect() as conn: - i = insert(bind_targets) - conn.execute(i, {"id": 1, "data": "initial data"}) - - eq_( - conn.scalar(sql.select(bind_targets.c.data)), - "initial data processed", - ) - - i = insert(bind_targets) - i = i.on_conflict_do_update( - index_elements=[bind_targets.c.id], - set_=dict(data="new updated data"), - ) - conn.execute(i, {"id": 1, "data": "new inserted data"}) - - eq_( - conn.scalar(sql.select(bind_targets.c.data)), - "new updated data processed", - ) + i = insert(bind_targets) + connection.execute(i, {"id": 1, "data": "initial data"}) + + eq_( + connection.scalar(sql.select(bind_targets.c.data)), + "initial data processed", + ) + + i = insert(bind_targets) + i = i.on_conflict_do_update( + index_elements=[bind_targets.c.id], + set_=dict(data="new updated data"), + ) + connection.execute(i, {"id": 1, "data": "new inserted data"}) + + eq_( + connection.scalar(sql.select(bind_targets.c.data)), + "new updated data processed", + ) diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index c959acf359..94af168eee 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -35,30 +35,32 @@ from sqlalchemy.testing.assertsql import CursorSQL from sqlalchemy.testing.assertsql import DialectSQL -matchtable = cattable = None - - class InsertTest(fixtures.TestBase, AssertsExecutionResults): __only_on__ = "postgresql" __backend__ = True - @classmethod - def setup_class(cls): - cls.metadata = MetaData(testing.db) + def setup(self): + self.metadata = MetaData() def teardown(self): - self.metadata.drop_all() - self.metadata.clear() + with testing.db.begin() as conn: + self.metadata.drop_all(conn) + + @testing.combinations((False,), (True,)) + def test_foreignkey_missing_insert(self, implicit_returning): + engine = engines.testing_engine( + options={"implicit_returning": implicit_returning} + ) - def test_foreignkey_missing_insert(self): Table("t1", self.metadata, Column("id", Integer, primary_key=True)) t2 = Table( "t2", self.metadata, Column("id", Integer, ForeignKey("t1.id"), primary_key=True), ) - self.metadata.create_all() + + self.metadata.create_all(engine) # want to ensure that "null value in column "id" violates not- # null constraint" is raised (IntegrityError on psycoopg2, but @@ -67,19 +69,13 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): # the latter corresponds to autoincrement behavior, which is not # the case here due to the foreign key. - for eng in [ - engines.testing_engine(options={"implicit_returning": False}), - engines.testing_engine(options={"implicit_returning": True}), - ]: - with expect_warnings( - ".*has no Python-side or server-side default.*" - ): - with eng.connect() as conn: - assert_raises( - (exc.IntegrityError, exc.ProgrammingError), - conn.execute, - t2.insert(), - ) + with expect_warnings(".*has no Python-side or server-side default.*"): + with engine.begin() as conn: + assert_raises( + (exc.IntegrityError, exc.ProgrammingError), + conn.execute, + t2.insert(), + ) def test_sequence_insert(self): table = Table( @@ -88,7 +84,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, Sequence("my_seq"), primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_with_sequence(table, "my_seq") @testing.requires.returning @@ -99,7 +95,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, Sequence("my_seq"), primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_with_sequence_returning(table, "my_seq") def test_opt_sequence_insert(self): @@ -114,7 +110,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement(table) @testing.requires.returning @@ -130,7 +126,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement_returning(table) def test_autoincrement_insert(self): @@ -140,7 +136,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement(table) @testing.requires.returning @@ -151,7 +147,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, primary_key=True), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_autoincrement_returning(table) def test_noautoincrement_insert(self): @@ -161,7 +157,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): Column("id", Integer, primary_key=True, autoincrement=False), Column("data", String(30)), ) - self.metadata.create_all() + self.metadata.create_all(testing.db) self._assert_data_noautoincrement(table) def _assert_data_autoincrement(self, table): @@ -169,7 +165,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: # execute with explicit id r = conn.execute(table.insert(), {"id": 30, "data": "d1"}) @@ -226,7 +222,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -250,7 +246,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table = Table(table.name, m2, autoload_with=engine) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, (5,)) @@ -288,7 +284,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}] ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -308,7 +304,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): engine = engines.testing_engine(options={"implicit_returning": True}) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: # execute with explicit id @@ -367,7 +363,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -390,7 +386,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table = Table(table.name, m2, autoload_with=engine) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) r = conn.execute(table.insert(), {"data": "d2"}) eq_(r.inserted_primary_key, (5,)) @@ -430,7 +426,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -450,7 +446,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): engine = engines.testing_engine(options={"implicit_returning": False}) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) conn.execute(table.insert(), {"data": "d2"}) conn.execute( @@ -491,7 +487,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): [{"data": "d8"}], ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -513,7 +509,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): engine = engines.testing_engine(options={"implicit_returning": True}) with self.sql_execution_asserter(engine) as asserter: - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) conn.execute(table.insert(), {"data": "d2"}) conn.execute( @@ -555,7 +551,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ), ) - with engine.connect() as conn: + with engine.begin() as conn: eq_( conn.execute(table.select()).fetchall(), [ @@ -578,9 +574,12 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): # turning off the cache because we are checking for compile-time # warnings - with engine.connect().execution_options(compiled_cache=None) as conn: + engine = engine.execution_options(compiled_cache=None) + + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -590,6 +589,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), {"data": "d2"}, ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -599,6 +600,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), [{"data": "d2"}, {"data": "d3"}], ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -608,6 +611,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), {"data": "d2"}, ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -618,6 +623,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): [{"data": "d2"}, {"data": "d3"}], ) + with engine.begin() as conn: conn.execute( table.insert(), [{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}], @@ -634,9 +640,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): m2 = MetaData() table = Table(table.name, m2, autoload_with=engine) - with engine.connect() as conn: + with engine.begin() as conn: conn.execute(table.insert(), {"id": 30, "data": "d1"}) + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -646,6 +653,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), {"data": "d2"}, ) + + with engine.begin() as conn: with expect_warnings( ".*has no Python-side or server-side default.*" ): @@ -655,6 +664,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): table.insert(), [{"data": "d2"}, {"data": "d3"}], ) + + with engine.begin() as conn: conn.execute( table.insert(), [{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}], @@ -666,36 +677,40 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults): ) -class MatchTest(fixtures.TestBase, AssertsCompiledSQL): +class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): __only_on__ = "postgresql >= 8.3" __backend__ = True @classmethod - def setup_class(cls): - global metadata, cattable, matchtable - metadata = MetaData(testing.db) - cattable = Table( + def define_tables(cls, metadata): + Table( "cattable", metadata, Column("id", Integer, primary_key=True), Column("description", String(50)), ) - matchtable = Table( + Table( "matchtable", metadata, Column("id", Integer, primary_key=True), Column("title", String(200)), Column("category_id", Integer, ForeignKey("cattable.id")), ) - metadata.create_all() - cattable.insert().execute( + + @classmethod + def insert_data(cls, connection): + cattable, matchtable = cls.tables("cattable", "matchtable") + + connection.execute( + cattable.insert(), [ {"id": 1, "description": "Python"}, {"id": 2, "description": "Ruby"}, - ] + ], ) - matchtable.insert().execute( + connection.execute( + matchtable.insert(), [ { "id": 1, @@ -714,15 +729,12 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): "category_id": 1, }, {"id": 5, "title": "Python in a Nutshell", "category_id": 1}, - ] + ], ) - @classmethod - def teardown_class(cls): - metadata.drop_all() - @testing.requires.pyformat_paramstyle def test_expression_pyformat(self): + matchtable = self.tables.matchtable self.assert_compile( matchtable.c.title.match("somstr"), "matchtable.title @@ to_tsquery(%(title_1)s" ")", @@ -730,51 +742,47 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): @testing.requires.format_paramstyle def test_expression_positional(self): + matchtable = self.tables.matchtable self.assert_compile( matchtable.c.title.match("somstr"), "matchtable.title @@ to_tsquery(%s)", ) - def test_simple_match(self): - results = ( + def test_simple_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([2, 5], [r.id for r in results]) - def test_not_match(self): - results = ( + def test_not_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( matchtable.select() .where(~matchtable.c.title.match("python")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 4], [r.id for r in results]) - def test_simple_match_with_apostrophe(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match("Matz's")) - .execute() - .fetchall() - ) + def test_simple_match_with_apostrophe(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match("Matz's")) + ).fetchall() eq_([3], [r.id for r in results]) - def test_simple_derivative_match(self): - results = ( - matchtable.select() - .where(matchtable.c.title.match("nutshells")) - .execute() - .fetchall() - ) + def test_simple_derivative_match(self, connection): + matchtable = self.tables.matchtable + results = connection.execute( + matchtable.select().where(matchtable.c.title.match("nutshells")) + ).fetchall() eq_([5], [r.id for r in results]) - def test_or_match(self): - results1 = ( + def test_or_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( matchtable.select() .where( or_( @@ -783,42 +791,36 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([3, 5], [r.id for r in results1]) - results2 = ( + results2 = connection.execute( matchtable.select() .where(matchtable.c.title.match("nutshells | rubies")) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([3, 5], [r.id for r in results2]) - def test_and_match(self): - results1 = ( - matchtable.select() - .where( + def test_and_match(self, connection): + matchtable = self.tables.matchtable + results1 = connection.execute( + matchtable.select().where( and_( matchtable.c.title.match("python"), matchtable.c.title.match("nutshells"), ) ) - .execute() - .fetchall() - ) + ).fetchall() eq_([5], [r.id for r in results1]) - results2 = ( - matchtable.select() - .where(matchtable.c.title.match("python & nutshells")) - .execute() - .fetchall() - ) + results2 = connection.execute( + matchtable.select().where( + matchtable.c.title.match("python & nutshells") + ) + ).fetchall() eq_([5], [r.id for r in results2]) - def test_match_across_joins(self): - results = ( + def test_match_across_joins(self, connection): + cattable, matchtable = self.tables("cattable", "matchtable") + results = connection.execute( matchtable.select() .where( and_( @@ -830,9 +832,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL): ) ) .order_by(matchtable.c.id) - .execute() - .fetchall() - ) + ).fetchall() eq_([1, 3, 5], [r.id for r in results]) diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 4de4d88e31..824f6cd36d 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -291,63 +291,64 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): @classmethod def setup_class(cls): - con = testing.db.connect() - for ddl in [ - 'CREATE SCHEMA "SomeSchema"', - "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", - "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", - "CREATE TYPE testtype AS ENUM ('test')", - "CREATE DOMAIN enumdomain AS testtype", - "CREATE DOMAIN arraydomain AS INTEGER[]", - 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', - ]: - try: - con.exec_driver_sql(ddl) - except exc.DBAPIError as e: - if "already exists" not in str(e): - raise e - con.exec_driver_sql( - "CREATE TABLE testtable (question integer, answer " "testdomain)" - ) - con.exec_driver_sql( - "CREATE TABLE test_schema.testtable(question " - "integer, answer test_schema.testdomain, anything " - "integer)" - ) - con.exec_driver_sql( - "CREATE TABLE crosschema (question integer, answer " - "test_schema.testdomain)" - ) + with testing.db.begin() as con: + for ddl in [ + 'CREATE SCHEMA "SomeSchema"', + "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", + "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", + "CREATE TYPE testtype AS ENUM ('test')", + "CREATE DOMAIN enumdomain AS testtype", + "CREATE DOMAIN arraydomain AS INTEGER[]", + 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', + ]: + try: + con.exec_driver_sql(ddl) + except exc.DBAPIError as e: + if "already exists" not in str(e): + raise e + con.exec_driver_sql( + "CREATE TABLE testtable (question integer, answer " + "testdomain)" + ) + con.exec_driver_sql( + "CREATE TABLE test_schema.testtable(question " + "integer, answer test_schema.testdomain, anything " + "integer)" + ) + con.exec_driver_sql( + "CREATE TABLE crosschema (question integer, answer " + "test_schema.testdomain)" + ) - con.exec_driver_sql( - "CREATE TABLE enum_test (id integer, data enumdomain)" - ) + con.exec_driver_sql( + "CREATE TABLE enum_test (id integer, data enumdomain)" + ) - con.exec_driver_sql( - "CREATE TABLE array_test (id integer, data arraydomain)" - ) + con.exec_driver_sql( + "CREATE TABLE array_test (id integer, data arraydomain)" + ) - con.exec_driver_sql( - "CREATE TABLE quote_test " - '(id integer, data "SomeSchema"."Quoted.Domain")' - ) + con.exec_driver_sql( + "CREATE TABLE quote_test " + '(id integer, data "SomeSchema"."Quoted.Domain")' + ) @classmethod def teardown_class(cls): - con = testing.db.connect() - con.exec_driver_sql("DROP TABLE testtable") - con.exec_driver_sql("DROP TABLE test_schema.testtable") - con.exec_driver_sql("DROP TABLE crosschema") - con.exec_driver_sql("DROP TABLE quote_test") - con.exec_driver_sql("DROP DOMAIN testdomain") - con.exec_driver_sql("DROP DOMAIN test_schema.testdomain") - con.exec_driver_sql("DROP TABLE enum_test") - con.exec_driver_sql("DROP DOMAIN enumdomain") - con.exec_driver_sql("DROP TYPE testtype") - con.exec_driver_sql("DROP TABLE array_test") - con.exec_driver_sql("DROP DOMAIN arraydomain") - con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') - con.exec_driver_sql('DROP SCHEMA "SomeSchema"') + with testing.db.begin() as con: + con.exec_driver_sql("DROP TABLE testtable") + con.exec_driver_sql("DROP TABLE test_schema.testtable") + con.exec_driver_sql("DROP TABLE crosschema") + con.exec_driver_sql("DROP TABLE quote_test") + con.exec_driver_sql("DROP DOMAIN testdomain") + con.exec_driver_sql("DROP DOMAIN test_schema.testdomain") + con.exec_driver_sql("DROP TABLE enum_test") + con.exec_driver_sql("DROP DOMAIN enumdomain") + con.exec_driver_sql("DROP TYPE testtype") + con.exec_driver_sql("DROP TABLE array_test") + con.exec_driver_sql("DROP DOMAIN arraydomain") + con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') + con.exec_driver_sql('DROP SCHEMA "SomeSchema"') def test_table_is_reflected(self): metadata = MetaData() @@ -486,7 +487,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("ref", Integer, ForeignKey("subject.id$")), ) - meta1.create_all() + meta1.create_all(testing.db) meta2 = MetaData() subject = Table("subject", meta2, autoload_with=testing.db) referer = Table("referer", meta2, autoload_with=testing.db) @@ -523,9 +524,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): with testing.db.begin() as conn: r = conn.execute(t2.insert()) eq_(r.inserted_primary_key, (1,)) - testing.db.connect().execution_options( - autocommit=True - ).exec_driver_sql("alter table t_id_seq rename to foobar_id_seq") + + with testing.db.begin() as conn: + conn.exec_driver_sql( + "alter table t_id_seq rename to foobar_id_seq" + ) m3 = MetaData() t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False) eq_( @@ -545,10 +548,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all() - testing.db.connect().execution_options( - autocommit=True - ).exec_driver_sql("alter table t alter column id type varchar(50)") + metadata.create_all(testing.db) + + with testing.db.begin() as conn: + conn.exec_driver_sql( + "alter table t alter column id type varchar(50)" + ) m2 = MetaData() t2 = Table("t", m2, autoload_with=testing.db) eq_(t2.c.id.autoincrement, False) @@ -558,10 +563,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): def test_renamed_pk_reflection(self): metadata = self.metadata Table("t", metadata, Column("id", Integer, primary_key=True)) - metadata.create_all() - testing.db.connect().execution_options( - autocommit=True - ).exec_driver_sql("alter table t rename id to t_id") + metadata.create_all(testing.db) + with testing.db.begin() as conn: + conn.exec_driver_sql("alter table t rename id to t_id") m2 = MetaData() t2 = Table("t", m2, autoload_with=testing.db) eq_([c.name for c in t2.primary_key], ["t_id"]) @@ -936,13 +940,13 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("name", String(20), index=True), Column("aname", String(20)), ) - metadata.create_all() - with testing.db.connect() as c: - c.exec_driver_sql("create index idx1 on party ((id || name))") - c.exec_driver_sql( + metadata.create_all(testing.db) + with testing.db.begin() as conn: + conn.exec_driver_sql("create index idx1 on party ((id || name))") + conn.exec_driver_sql( "create unique index idx2 on party (id) where name = 'test'" ) - c.exec_driver_sql( + conn.exec_driver_sql( """ create index idx3 on party using btree (lower(name::text), lower(aname::text)) @@ -1029,7 +1033,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("aname", String(20)), ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: t1.create(conn) @@ -1109,18 +1113,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all() - conn = testing.db.connect().execution_options(autocommit=True) - conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)") - conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") + metadata.create_all(testing.db) + with testing.db.begin() as conn: + conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)") + conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y") - ind = testing.db.dialect.get_indexes(conn, "t", None) - expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}] - if testing.requires.index_reflects_included_columns.enabled: - expected[0]["include_columns"] = [] + ind = testing.db.dialect.get_indexes(conn, "t", None) + expected = [ + {"name": "idx1", "unique": False, "column_names": ["y"]} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] - eq_(ind, expected) - conn.close() + eq_(ind, expected) @testing.fails_if("postgresql < 8.2", "reloptions not supported") @testing.provide_metadata @@ -1135,9 +1140,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", Integer), ) - metadata.create_all() + metadata.create_all(testing.db) - with testing.db.connect().execution_options(autocommit=True) as conn: + with testing.db.begin() as conn: conn.exec_driver_sql( "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)" ) @@ -1177,8 +1182,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("id", Integer, primary_key=True), Column("x", ARRAY(Integer)), ) - metadata.create_all() - with testing.db.connect().execution_options(autocommit=True) as conn: + metadata.create_all(testing.db) + with testing.db.begin() as conn: conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)") ind = testing.db.dialect.get_indexes(conn, "t", None) @@ -1215,7 +1220,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): Column("name", String(20)), ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql("CREATE INDEX idx1 ON t (x) INCLUDE (name)") # prior to #5205, this would return: @@ -1312,8 +1317,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): eq_(fk, fk_ref[fk["name"]]) @testing.provide_metadata - def test_inspect_enums_schema(self): - conn = testing.db.connect() + def test_inspect_enums_schema(self, connection): enum_type = postgresql.ENUM( "sad", "ok", @@ -1322,8 +1326,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): schema="test_schema", metadata=self.metadata, ) - enum_type.create(conn) - inspector = inspect(conn) + enum_type.create(connection) + inspector = inspect(connection) eq_( inspector.get_enums("test_schema"), [ diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index e7174f234a..ae7a65a3af 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -206,7 +206,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ), schema=symbol_name, ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn = conn.execution_options( schema_translate_map={symbol_name: testing.config.test_schema} ) diff --git a/test/dialect/test_mxodbc.py b/test/dialect/test_mxodbc.py index de8b22b67c..cd8768d73b 100644 --- a/test/dialect/test_mxodbc.py +++ b/test/dialect/test_mxodbc.py @@ -30,34 +30,37 @@ class MxODBCTest(fixtures.TestBase): ) conn = engine.connect() - # crud: uses execute - conn.execute(t1.insert().values(c1="foo")) - conn.execute(t1.delete().where(t1.c.c1 == "foo")) - conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar")) - - # select: uses executedirect - conn.execute(t1.select()) - - # manual flagging - conn.execution_options(native_odbc_execute=True).execute(t1.select()) - conn.execution_options(native_odbc_execute=False).execute( - t1.insert().values(c1="foo") - ) + with conn.begin(): + # crud: uses execute + conn.execute(t1.insert().values(c1="foo")) + conn.execute(t1.delete().where(t1.c.c1 == "foo")) + conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar")) - eq_( - # fmt: off - [ - c[2] - for c in dbapi.connect.return_value.cursor. - return_value.execute.mock_calls - ], - # fmt: on - [ - {"direct": True}, - {"direct": True}, - {"direct": True}, - {"direct": True}, - {"direct": False}, - {"direct": True}, - ] - ) + # select: uses executedirect + conn.execute(t1.select()) + + # manual flagging + conn.execution_options(native_odbc_execute=True).execute( + t1.select() + ) + conn.execution_options(native_odbc_execute=False).execute( + t1.insert().values(c1="foo") + ) + + eq_( + # fmt: off + [ + c[2] + for c in dbapi.connect.return_value.cursor. + return_value.execute.mock_calls + ], + # fmt: on + [ + {"direct": True}, + {"direct": True}, + {"direct": True}, + {"direct": True}, + {"direct": False}, + {"direct": True}, + ] + ) diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index f8b50f8883..12200f832f 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -63,8 +63,9 @@ from sqlalchemy.util import ue def exec_sql(engine, sql, *args, **kwargs): - conn = engine.connect(close_with_result=True) - return conn.exec_driver_sql(sql, *args, **kwargs) + # TODO: convert all tests to not use this + with engine.begin() as conn: + conn.exec_driver_sql(sql, *args, **kwargs) class TestTypes(fixtures.TestBase, AssertsExecutionResults): @@ -189,11 +190,13 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): connection.execute( t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0)) ) - exec_sql( - testing.db, "insert into t (d) values ('2004-05-21T00:00:00')" + connection.exec_driver_sql( + "insert into t (d) values ('2004-05-21T00:00:00')" ) eq_( - exec_sql(testing.db, "select * from t order by d").fetchall(), + connection.exec_driver_sql( + "select * from t order by d" + ).fetchall(), [("2004-05-21T00:00:00",), ("2010-10-15T12:37:00",)], ) eq_( @@ -216,9 +219,13 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): connection.execute( t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0)) ) - exec_sql(testing.db, "insert into t (d) values ('20040521000000')") + connection.exec_driver_sql( + "insert into t (d) values ('20040521000000')" + ) eq_( - exec_sql(testing.db, "select * from t order by d").fetchall(), + connection.exec_driver_sql( + "select * from t order by d" + ).fetchall(), [("20040521000000",), ("20101015123700",)], ) eq_( @@ -238,9 +245,11 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): t = Table("t", self.metadata, Column("d", sqlite_date)) self.metadata.create_all(connection) connection.execute(t.insert().values(d=datetime.date(2010, 10, 15))) - exec_sql(testing.db, "insert into t (d) values ('20040521')") + connection.exec_driver_sql("insert into t (d) values ('20040521')") eq_( - exec_sql(testing.db, "select * from t order by d").fetchall(), + connection.exec_driver_sql( + "select * from t order by d" + ).fetchall(), [("20040521",), ("20101015",)], ) eq_( @@ -256,11 +265,15 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults): regexp=r"(\d+)\|(\d+)\|(\d+)", ) t = Table("t", self.metadata, Column("d", sqlite_date)) - self.metadata.create_all(testing.db) + self.metadata.create_all(connection) connection.execute(t.insert().values(d=datetime.date(2010, 10, 15))) - exec_sql(testing.db, "insert into t (d) values ('2004|05|21')") + + connection.exec_driver_sql("insert into t (d) values ('2004|05|21')") + eq_( - exec_sql(testing.db, "select * from t order by d").fetchall(), + connection.exec_driver_sql( + "select * from t order by d" + ).fetchall(), [("2004|05|21",), ("2010|10|15",)], ) eq_( @@ -313,7 +326,7 @@ class JSONTest(fixtures.TestBase): value = {"json": {"foo": "bar"}, "recs": ["one", "two"]} - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(sqlite_json.insert(), foo=value) eq_(conn.scalar(select(sqlite_json.c.foo)), value) @@ -328,7 +341,7 @@ class JSONTest(fixtures.TestBase): value = {"json": {"foo": "bar"}} - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(sqlite_json.insert(), foo=value) eq_(conn.scalar(select(sqlite_json.c.foo["json"])), value["json"]) @@ -551,7 +564,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): Column("x", Boolean, server_default=sql.false()), ) t.create(testing.db) - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(t.insert()) conn.execute(t.insert().values(x=True)) eq_( @@ -568,7 +581,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): Column("x", DateTime(), server_default=func.now()), ) t.create(testing.db) - with testing.db.connect() as conn: + with testing.db.begin() as conn: now = conn.scalar(func.now()) today = datetime.datetime.today() conn.execute(t.insert()) @@ -587,7 +600,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL): Column("x", Integer(), server_default=func.abs(-5) + 17), ) t.create(testing.db) - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(t.insert()) conn.execute(t.insert().values(x=35)) eq_( @@ -622,7 +635,8 @@ class DialectTest( ) ) - def test_extra_reserved_words(self): + @testing.provide_metadata + def test_extra_reserved_words(self, connection): """Tests reserved words in identifiers. 'true', 'false', and 'column' are undocumented reserved words @@ -630,22 +644,19 @@ class DialectTest( here to ensure they remain in place if the dialect's reserved_words set is updated in the future.""" - meta = MetaData(testing.db) t = Table( "reserved", - meta, + self.metadata, Column("safe", Integer), Column("true", Integer), Column("false", Integer), Column("column", Integer), Column("exists", Integer), ) - try: - meta.create_all() - t.insert().execute(safe=1) - list(t.select().execute()) - finally: - meta.drop_all() + self.metadata.create_all(connection) + connection.execute(t.insert(), dict(safe=1)) + result = connection.execute(t.select()) + eq_(list(result), [(1, None, None, None, None)]) @testing.provide_metadata def test_quoted_identifiers_functional_one(self): @@ -827,7 +838,8 @@ class AttachedDBTest(fixtures.TestBase): schema="test_schema", ) - meta.create_all(self.conn) + with self.conn.begin(): + meta.create_all(self.conn) return ct def setup(self): @@ -835,7 +847,8 @@ class AttachedDBTest(fixtures.TestBase): self.metadata = MetaData() def teardown(self): - self.metadata.drop_all(self.conn) + with self.conn.begin(): + self.metadata.drop_all(self.conn) self.conn.close() def test_no_tables(self): @@ -928,18 +941,20 @@ class AttachedDBTest(fixtures.TestBase): def test_crud(self): ct = self._fixture() - self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) - eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")]) + with self.conn.begin(): + self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) + eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")]) - self.conn.execute(ct.update(), {"id": 2, "name": "bar"}) - eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")]) - self.conn.execute(ct.delete()) - eq_(self.conn.execute(ct.select()).fetchall(), []) + self.conn.execute(ct.update(), {"id": 2, "name": "bar"}) + eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")]) + self.conn.execute(ct.delete()) + eq_(self.conn.execute(ct.select()).fetchall(), []) def test_col_targeting(self): ct = self._fixture() - self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) + with self.conn.begin(): + self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) row = self.conn.execute(ct.select()).first() eq_(row._mapping["id"], 1) eq_(row._mapping["name"], "foo") @@ -947,7 +962,8 @@ class AttachedDBTest(fixtures.TestBase): def test_col_targeting_union(self): ct = self._fixture() - self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) + with self.conn.begin(): + self.conn.execute(ct.insert(), {"id": 1, "name": "foo"}) row = self.conn.execute(ct.select().union(ct.select())).first() eq_(row._mapping["id"], 1) eq_(row._mapping["name"], "foo") @@ -2236,7 +2252,7 @@ class ConstraintReflectionTest(fixtures.TestBase): ) def test_foreign_key_options_unnamed_inline(self): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql( "create table foo (id integer, " "foreign key (id) references bar (id) on update cascade)" @@ -2571,33 +2587,33 @@ class TypeReflectionTest(fixtures.TestBase): def _test_round_trip(self, fixture, warnings=False): from sqlalchemy import inspect - conn = testing.db.connect() for from_, to_ in self._fixture_as_string(fixture): - inspector = inspect(conn) - conn.exec_driver_sql("CREATE TABLE foo (data %s)" % from_) - try: - if warnings: + with testing.db.begin() as conn: + inspector = inspect(conn) + conn.exec_driver_sql("CREATE TABLE foo (data %s)" % from_) + try: + if warnings: - def go(): - return inspector.get_columns("foo")[0] + def go(): + return inspector.get_columns("foo")[0] - col_info = testing.assert_warnings( - go, ["Could not instantiate"], regex=True - ) - else: - col_info = inspector.get_columns("foo")[0] - expected_type = type(to_) - is_(type(col_info["type"]), expected_type) - - # test args - for attr in ("scale", "precision", "length"): - if getattr(to_, attr, None) is not None: - eq_( - getattr(col_info["type"], attr), - getattr(to_, attr, None), + col_info = testing.assert_warnings( + go, ["Could not instantiate"], regex=True ) - finally: - conn.exec_driver_sql("DROP TABLE foo") + else: + col_info = inspector.get_columns("foo")[0] + expected_type = type(to_) + is_(type(col_info["type"]), expected_type) + + # test args + for attr in ("scale", "precision", "length"): + if getattr(to_, attr, None) is not None: + eq_( + getattr(col_info["type"], attr), + getattr(to_, attr, None), + ) + finally: + conn.exec_driver_sql("DROP TABLE foo") def test_lookup_direct_lookup(self): self._test_lookup_direct(self._fixed_lookup_fixture()) diff --git a/test/engine/test_ddlevents.py b/test/engine/test_ddlevents.py index f2429175f9..5cbb478546 100644 --- a/test/engine/test_ddlevents.py +++ b/test/engine/test_ddlevents.py @@ -489,6 +489,7 @@ class DDLExecutionTest(fixtures.TestBase): def test_ddl_execute(self): engine = create_engine("sqlite:///") cx = engine.connect() + cx.begin() table = self.users ddl = DDL("SELECT 1") diff --git a/test/engine/test_deprecations.py b/test/engine/test_deprecations.py index 5e32cc3e96..47e59b55da 100644 --- a/test/engine/test_deprecations.py +++ b/test/engine/test_deprecations.py @@ -93,6 +93,9 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): for meta in (MetaData, ThreadLocalMetaData): for bind in (testing.db, testing.db.connect()): + if isinstance(bind, engine.Connection): + bind.begin() + if meta is ThreadLocalMetaData: with testing.expect_deprecated( "ThreadLocalMetaData is deprecated" @@ -151,6 +154,8 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): def test_bind_create_drop_constructor_bound(self): for bind in (testing.db, testing.db.connect()): + if isinstance(bind, engine.Connection): + bind.begin() try: for args in (([bind], {}), ([], {"bind": bind})): metadata = MetaData(*args[0], **args[1]) @@ -177,15 +182,25 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): test_needs_acid=True, ) conn = testing.db.connect() - metadata.create_all(bind=conn) + with conn.begin(): + metadata.create_all(bind=conn) try: trans = conn.begin() metadata.bind = conn t = table.insert() assert t.bind is conn - table.insert().execute(foo=5) - table.insert().execute(foo=6) - table.insert().execute(foo=7) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + table.insert().execute(foo=5) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + table.insert().execute(foo=6) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + table.insert().execute(foo=7) trans.rollback() metadata.bind = None assert ( @@ -195,7 +210,8 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): == 0 ) finally: - metadata.drop_all(bind=conn) + with conn.begin(): + metadata.drop_all(bind=conn) def test_bind_clauseelement(self): metadata = MetaData() @@ -215,14 +231,21 @@ class ConnectionlessDeprecationTest(fixtures.TestBase): ): e = elem(bind=bind) assert e.bind is bind - e.execute().close() + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is " + "considered legacy" + ): + e.execute().close() finally: if isinstance(bind, engine.Connection): bind.close() e = elem() assert e.bind is None - assert_raises(exc.UnboundExecutionError, e.execute) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + assert_raises(exc.UnboundExecutionError, e.execute) finally: if isinstance(bind, engine.Connection): bind.close() @@ -365,6 +388,11 @@ class TransactionTest(fixtures.TablesTest): ) Table("inserttable", metadata, Column("data", String(20))) + @testing.fixture + def local_connection(self): + with testing.db.connect() as conn: + yield conn + def test_transaction_container(self): users = self.tables.users @@ -429,6 +457,110 @@ class TransactionTest(fixtures.TablesTest): "insert into inserttable (data) values ('thedata')" ) + def test_branch_autorollback(self, local_connection): + connection = local_connection + users = self.tables.users + branched = connection.connect() + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + branched.execute( + users.insert(), dict(user_id=1, user_name="user1") + ) + assert_raises( + exc.DBAPIError, + branched.execute, + users.insert(), + dict(user_id=1, user_name="user1"), + ) + # can continue w/o issue + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + branched.execute( + users.insert(), dict(user_id=2, user_name="user2") + ) + + def test_branch_orig_rollback(self, local_connection): + connection = local_connection + users = self.tables.users + branched = connection.connect() + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + branched.execute( + users.insert(), dict(user_id=1, user_name="user1") + ) + nested = branched.begin() + assert branched.in_transaction() + branched.execute(users.insert(), dict(user_id=2, user_name="user2")) + nested.rollback() + eq_( + connection.exec_driver_sql("select count(*) from users").scalar(), + 1, + ) + + @testing.requires.independent_connections + def test_branch_autocommit(self, local_connection): + users = self.tables.users + with testing.db.connect() as connection: + branched = connection.connect() + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + branched.execute( + users.insert(), dict(user_id=1, user_name="user1") + ) + + eq_( + local_connection.execute( + text("select count(*) from users") + ).scalar(), + 1, + ) + + @testing.requires.savepoints + def test_branch_savepoint_rollback(self, local_connection): + connection = local_connection + users = self.tables.users + trans = connection.begin() + branched = connection.connect() + assert branched.in_transaction() + branched.execute(users.insert(), user_id=1, user_name="user1") + nested = branched.begin_nested() + branched.execute(users.insert(), user_id=2, user_name="user2") + nested.rollback() + assert connection.in_transaction() + trans.commit() + eq_( + connection.exec_driver_sql("select count(*) from users").scalar(), + 1, + ) + + @testing.requires.two_phase_transactions + def test_branch_twophase_rollback(self, local_connection): + connection = local_connection + users = self.tables.users + branched = connection.connect() + assert not branched.in_transaction() + with testing.expect_deprecated_20( + r"The current statement is being autocommitted using " + "implicit autocommit" + ): + branched.execute(users.insert(), user_id=1, user_name="user1") + nested = branched.begin_twophase() + branched.execute(users.insert(), user_id=2, user_name="user2") + nested.rollback() + assert not connection.in_transaction() + eq_( + connection.exec_driver_sql("select count(*) from users").scalar(), + 1, + ) + class HandleInvalidatedOnConnectTest(fixtures.TestBase): __requires__ = ("sqlite",) @@ -699,20 +831,20 @@ class DeprecatedReflectionTest(fixtures.TablesTest): def test_create_drop_explicit(self): metadata = MetaData() table = Table("test_table", metadata, Column("foo", Integer)) - for bind in (testing.db, testing.db.connect()): - for args in [([], {"bind": bind}), ([bind], {})]: - metadata.create_all(*args[0], **args[1]) - with testing.expect_deprecated( - r"The Table.exists\(\) method is deprecated" - ): - assert table.exists(*args[0], **args[1]) - metadata.drop_all(*args[0], **args[1]) - table.create(*args[0], **args[1]) - table.drop(*args[0], **args[1]) - with testing.expect_deprecated( - r"The Table.exists\(\) method is deprecated" - ): - assert not table.exists(*args[0], **args[1]) + bind = testing.db + for args in [([], {"bind": bind}), ([bind], {})]: + metadata.create_all(*args[0], **args[1]) + with testing.expect_deprecated( + r"The Table.exists\(\) method is deprecated" + ): + assert table.exists(*args[0], **args[1]) + metadata.drop_all(*args[0], **args[1]) + table.create(*args[0], **args[1]) + table.drop(*args[0], **args[1]) + with testing.expect_deprecated( + r"The Table.exists\(\) method is deprecated" + ): + assert not table.exists(*args[0], **args[1]) def test_create_drop_err_table(self): metadata = MetaData() @@ -1195,3 +1327,208 @@ class DDLExecutionTest(fixtures.TestBase): with testing.expect_deprecated_20(ddl_msg): r = fn(**kw) eq_(list(r), [(1,)]) + + +class AutocommitKeywordFixture(object): + def _test_keyword(self, keyword, expected=True): + dbapi = Mock( + connect=Mock( + return_value=Mock( + cursor=Mock(return_value=Mock(description=())) + ) + ) + ) + engine = engines.testing_engine( + options={"_initialize": False, "pool_reset_on_return": None} + ) + engine.dialect.dbapi = dbapi + + with engine.connect() as conn: + if expected: + with testing.expect_deprecated_20( + "The current statement is being autocommitted " + "using implicit autocommit" + ): + conn.exec_driver_sql( + "%s something table something" % keyword + ) + else: + conn.exec_driver_sql("%s something table something" % keyword) + + if expected: + eq_( + [n for (n, k, s) in dbapi.connect().mock_calls], + ["cursor", "commit"], + ) + else: + eq_( + [n for (n, k, s) in dbapi.connect().mock_calls], ["cursor"] + ) + + +class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase): + __backend__ = True + + def test_update(self): + self._test_keyword("UPDATE") + + def test_insert(self): + self._test_keyword("INSERT") + + def test_delete(self): + self._test_keyword("DELETE") + + def test_alter(self): + self._test_keyword("ALTER TABLE") + + def test_create(self): + self._test_keyword("CREATE TABLE foobar") + + def test_drop(self): + self._test_keyword("DROP TABLE foobar") + + def test_select(self): + self._test_keyword("SELECT foo FROM table", False) + + +class ExplicitAutoCommitTest(fixtures.TestBase): + + """test the 'autocommit' flag on select() and text() objects. + + Requires PostgreSQL so that we may define a custom function which + modifies the database.""" + + __only_on__ = "postgresql" + + @classmethod + def setup_class(cls): + global metadata, foo + metadata = MetaData(testing.db) + foo = Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) + with testing.db.begin() as conn: + metadata.create_all(conn) + conn.exec_driver_sql( + "create function insert_foo(varchar) " + "returns integer as 'insert into foo(data) " + "values ($1);select 1;' language sql" + ) + + def teardown(self): + with testing.db.begin() as conn: + conn.execute(foo.delete()) + + @classmethod + def teardown_class(cls): + with testing.db.begin() as conn: + conn.exec_driver_sql("drop function insert_foo(varchar)") + metadata.drop_all(conn) + + def test_control(self): + + # test that not using autocommit does not commit + + conn1 = testing.db.connect() + conn2 = testing.db.connect() + conn1.execute(select(func.insert_foo("data1"))) + assert conn2.execute(select(foo.c.data)).fetchall() == [] + conn1.execute(text("select insert_foo('moredata')")) + assert conn2.execute(select(foo.c.data)).fetchall() == [] + trans = conn1.begin() + trans.commit() + assert conn2.execute(select(foo.c.data)).fetchall() == [ + ("data1",), + ("moredata",), + ] + conn1.close() + conn2.close() + + def test_explicit_compiled(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + conn1.execute( + select(func.insert_foo("data1")).execution_options( + autocommit=True + ) + ) + assert conn2.execute(select(foo.c.data)).fetchall() == [("data1",)] + conn1.close() + conn2.close() + + def test_explicit_connection(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + conn1.execution_options(autocommit=True).execute( + select(func.insert_foo("data1")) + ) + eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)]) + + # connection supersedes statement + + conn1.execution_options(autocommit=False).execute( + select(func.insert_foo("data2")).execution_options(autocommit=True) + ) + eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)]) + + # ditto + + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + conn1.execution_options(autocommit=True).execute( + select(func.insert_foo("data3")).execution_options( + autocommit=False + ) + ) + eq_( + conn2.execute(select(foo.c.data)).fetchall(), + [("data1",), ("data2",), ("data3",)], + ) + conn1.close() + conn2.close() + + def test_explicit_text(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + conn1.execute( + text("select insert_foo('moredata')").execution_options( + autocommit=True + ) + ) + assert conn2.execute(select(foo.c.data)).fetchall() == [("moredata",)] + conn1.close() + conn2.close() + + def test_implicit_text(self): + conn1 = testing.db.connect() + conn2 = testing.db.connect() + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit" + ): + conn1.execute( + text("insert into foo (data) values ('implicitdata')") + ) + assert conn2.execute(select(foo.c.data)).fetchall() == [ + ("implicitdata",) + ] + conn1.close() + conn2.close() diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index efec9376c1..55a114409b 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -543,13 +543,15 @@ class ExecuteTest(fixtures.TablesTest): @testing.only_on("sqlite") def test_execute_compiled_favors_compiled_paramstyle(self): + users = self.tables.users + with patch.object(testing.db.dialect, "do_execute") as do_exec: stmt = users.update().values(user_id=1, user_name="foo") d1 = default.DefaultDialect(paramstyle="format") d2 = default.DefaultDialect(paramstyle="pyformat") - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(stmt.compile(dialect=d1)) conn.execute(stmt.compile(dialect=d2)) @@ -805,9 +807,8 @@ class ConvenienceExecuteTest(fixtures.TablesTest): def test_connection_as_ctx(self): fn = self._trans_fn() - ctx = testing.db.connect() - testing.run_as_contextmanager(ctx, fn, 5, value=8) - # autocommit is on + with testing.db.begin() as conn: + fn(conn, 5, value=8) self._assert_fn(5, value=8) @testing.fails_on("mysql+oursql", "oursql bug ? getting wrong rowcount") @@ -822,14 +823,12 @@ class ConvenienceExecuteTest(fixtures.TablesTest): self._assert_no_data() -class CompiledCacheTest(fixtures.TestBase): +class CompiledCacheTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - global users, metadata - metadata = MetaData(testing.db) - users = Table( + def define_tables(cls, metadata): + Table( "users", metadata, Column( @@ -838,19 +837,11 @@ class CompiledCacheTest(fixtures.TestBase): Column("user_name", VARCHAR(20)), Column("extra_data", VARCHAR(20)), ) - metadata.create_all() - @engines.close_first - def teardown(self): - with testing.db.connect() as conn: - conn.execute(users.delete()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() + def test_cache(self, connection): + users = self.tables.users - def test_cache(self): - conn = testing.db.connect() + conn = connection cache = {} cached_conn = conn.execution_options(compiled_cache=cache) @@ -870,7 +861,7 @@ class CompiledCacheTest(fixtures.TestBase): "uses blob value that is problematic for some DBAPIs", ) @testing.provide_metadata - def test_cache_noleak_on_statement_values(self): + def test_cache_noleak_on_statement_values(self, connection): # This is a non regression test for an object reference leak caused # by the compiled_cache. @@ -883,11 +874,10 @@ class CompiledCacheTest(fixtures.TestBase): ), Column("photo_blob", LargeBinary()), ) - metadata.create_all() + metadata.create_all(connection) - conn = testing.db.connect() cache = {} - cached_conn = conn.execution_options(compiled_cache=cache) + cached_conn = connection.execution_options(compiled_cache=cache) class PhotoBlob(bytearray): pass @@ -902,7 +892,10 @@ class CompiledCacheTest(fixtures.TestBase): cached_conn.execute(ins, {"photo_blob": blob}) eq_(compile_mock.call_count, 1) eq_(len(cache), 1) - eq_(conn.exec_driver_sql("select count(*) from photo").scalar(), 1) + eq_( + connection.exec_driver_sql("select count(*) from photo").scalar(), + 1, + ) del blob @@ -912,14 +905,15 @@ class CompiledCacheTest(fixtures.TestBase): # the statement values (only the keys). eq_(ref_blob(), None) - def test_keys_independent_of_ordering(self): - conn = testing.db.connect() - conn.execute( + def test_keys_independent_of_ordering(self, connection): + users = self.tables.users + + connection.execute( users.insert(), {"user_id": 1, "user_name": "u1", "extra_data": "e1"}, ) cache = {} - cached_conn = conn.execution_options(compiled_cache=cache) + cached_conn = connection.execution_options(compiled_cache=cache) upd = users.update().where(users.c.user_id == bindparam("b_user_id")) @@ -974,30 +968,32 @@ class CompiledCacheTest(fixtures.TestBase): stmt = select(t1.c.q) cache = {} - with config.db.connect().execution_options( - compiled_cache=cache - ) as conn: + with config.db.begin() as conn: + conn = conn.execution_options(compiled_cache=cache) conn.execute(ins, {"q": 1}) eq_(conn.scalar(stmt), 1) - with config.db.connect().execution_options( - compiled_cache=cache, - schema_translate_map={None: config.test_schema}, - ) as conn: + with config.db.begin() as conn: + conn = conn.execution_options( + compiled_cache=cache, + schema_translate_map={None: config.test_schema}, + ) conn.execute(ins, {"q": 2}) eq_(conn.scalar(stmt), 2) - with config.db.connect().execution_options( - compiled_cache=cache, - schema_translate_map={None: None}, - ) as conn: + with config.db.begin() as conn: + conn = conn.execution_options( + compiled_cache=cache, + schema_translate_map={None: None}, + ) # should use default schema again even though statement # was compiled with test_schema in the map eq_(conn.scalar(stmt), 1) - with config.db.connect().execution_options( - compiled_cache=cache - ) as conn: + with config.db.begin() as conn: + conn = conn.execution_options( + compiled_cache=cache, + ) eq_(conn.scalar(stmt), 1) @@ -1050,7 +1046,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t3 = Table("t3", metadata, Column("x", Integer), schema="bar") with self.sql_execution_asserter(config.db) as asserter: - with config.db.connect().execution_options( + with config.db.begin() as conn, conn.execution_options( schema_translate_map=map_ ) as conn: @@ -1091,9 +1087,8 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): Table("t2", metadata, Column("x", Integer), schema="foo") Table("t3", metadata, Column("x", Integer), schema="bar") - with config.db.connect().execution_options( - schema_translate_map=map_ - ) as conn: + with config.db.begin() as conn: + conn = conn.execution_options(schema_translate_map=map_) metadata.create_all(conn) insp = inspect(config.db) @@ -1101,9 +1096,8 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): is_true(insp.has_table("t2", schema=config.test_schema)) is_true(insp.has_table("t3", schema=None)) - with config.db.connect().execution_options( - schema_translate_map=map_ - ) as conn: + with config.db.begin() as conn: + conn = conn.execution_options(schema_translate_map=map_) metadata.drop_all(conn) insp = inspect(config.db) @@ -1127,7 +1121,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t3 = Table("t3", metadata, Column("x", Integer), schema="bar") with self.sql_execution_asserter(config.db) as asserter: - with config.db.connect() as conn: + with config.db.begin() as conn: execution_options = {"schema_translate_map": map_} conn._execute_20( @@ -1222,7 +1216,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults): t3 = Table("t3", metadata, Column("x", Integer), schema="bar") with self.sql_execution_asserter(config.db) as asserter: - with config.db.connect().execution_options( + with config.db.begin() as conn, conn.execution_options( schema_translate_map=map_ ) as conn: @@ -1790,6 +1784,7 @@ class EngineEventsTest(fixtures.TestBase): else: ctx = conn = engine.connect() + trans = conn.begin() try: m.create_all(conn, checkfirst=False) try: @@ -1801,8 +1796,7 @@ class EngineEventsTest(fixtures.TestBase): ) finally: m.drop_all(conn) - if engine._is_future: - conn.commit() + trans.commit() finally: if ctx: ctx.close() @@ -3046,7 +3040,7 @@ class DialectEventTest(fixtures.TestBase): m1.do_execute_no_params.side_effect ) = mock_the_cursor - with e.connect() as conn: + with e.begin() as conn: yield conn, m1 def _assert(self, retval, m1, m2, mock_calls): @@ -3244,59 +3238,6 @@ class DialectEventTest(fixtures.TestBase): eq_(conn.info["boom"], "one") -class AutocommitKeywordFixture(object): - def _test_keyword(self, keyword, expected=True): - dbapi = Mock( - connect=Mock( - return_value=Mock( - cursor=Mock(return_value=Mock(description=())) - ) - ) - ) - engine = engines.testing_engine( - options={"_initialize": False, "pool_reset_on_return": None} - ) - engine.dialect.dbapi = dbapi - - with engine.connect() as conn: - conn.exec_driver_sql("%s something table something" % keyword) - - if expected: - eq_( - [n for (n, k, s) in dbapi.connect().mock_calls], - ["cursor", "commit"], - ) - else: - eq_( - [n for (n, k, s) in dbapi.connect().mock_calls], ["cursor"] - ) - - -class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase): - __backend__ = True - - def test_update(self): - self._test_keyword("UPDATE") - - def test_insert(self): - self._test_keyword("INSERT") - - def test_delete(self): - self._test_keyword("DELETE") - - def test_alter(self): - self._test_keyword("ALTER TABLE") - - def test_create(self): - self._test_keyword("CREATE TABLE foobar") - - def test_drop(self): - self._test_keyword("DROP TABLE foobar") - - def test_select(self): - self._test_keyword("SELECT foo FROM table", False) - - class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest): __backend__ = True @@ -3463,7 +3404,7 @@ class SetInputSizesTest(fixtures.TablesTest): def test_set_input_sizes_no_event(self, input_sizes_fixture): engine, canary = input_sizes_fixture - with engine.connect() as conn: + with engine.begin() as conn: conn.execute( self.tables.users.insert(), [ @@ -3596,7 +3537,7 @@ class SetInputSizesTest(fixtures.TablesTest): 0, ) - with engine.connect() as conn: + with engine.begin() as conn: conn.execute( self.tables.users.insert(), [ diff --git a/test/engine/test_logging.py b/test/engine/test_logging.py index aa272c0cf5..29b8132aa3 100644 --- a/test/engine/test_logging.py +++ b/test/engine/test_logging.py @@ -22,7 +22,7 @@ from sqlalchemy.testing.util import lazy_gc def exec_sql(engine, sql, *args, **kwargs): - with engine.connect() as conn: + with engine.begin() as conn: return conn.exec_driver_sql(sql, *args, **kwargs) @@ -56,7 +56,7 @@ class LogParamsTest(fixtures.TestBase): [{"data": str(i)} for i in range(100)], ) eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] [{'data': '0'}, {'data': '1'}, {'data': '2'}, " "{'data': '3'}, " "{'data': '4'}, {'data': '5'}, {'data': '6'}, {'data': '7'}" @@ -86,7 +86,7 @@ class LogParamsTest(fixtures.TestBase): [{"data": str(i)} for i in range(100)], ) eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] [SQL parameters hidden due to hide_parameters=True]", ) @@ -97,7 +97,7 @@ class LogParamsTest(fixtures.TestBase): [(str(i),) for i in range(100)], ) eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] [('0',), ('1',), ('2',), ('3',), ('4',), ('5',), " "('6',), ('7',) ... displaying 10 of 100 total " "bound parameter sets ... ('98',), ('99',)]", @@ -227,7 +227,7 @@ class LogParamsTest(fixtures.TestBase): exec_sql(self.eng, "INSERT INTO foo (data) values (?)", (largeparam,)) eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] ('%s ... (4702 characters truncated) ... %s',)" % (largeparam[0:149], largeparam[-149:]), ) @@ -242,7 +242,7 @@ class LogParamsTest(fixtures.TestBase): exec_sql(self.eng, "SELECT ?, ?, ?", (lp1, lp2, lp3)) eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] ('%s', '%s', '%s ... (372 characters truncated) " "... %s')" % (lp1, lp2, lp3[0:149], lp3[-149:]), ) @@ -261,7 +261,7 @@ class LogParamsTest(fixtures.TestBase): ) eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] [('%s ... (4702 characters truncated) ... %s',), " "('%s',), " "('%s ... (372 characters truncated) ... %s',)]" @@ -347,20 +347,20 @@ class LogParamsTest(fixtures.TestBase): row = result.first() eq_( - self.buf.buffer[1].message, + self.buf.buffer[2].message, "[raw sql] ('%s ... (4702 characters truncated) ... %s',)" % (largeparam[0:149], largeparam[-149:]), ) if util.py3k: eq_( - self.buf.buffer[3].message, + self.buf.buffer[5].message, "Row ('%s ... (4702 characters truncated) ... %s',)" % (largeparam[0:149], largeparam[-149:]), ) else: eq_( - self.buf.buffer[3].message, + self.buf.buffer[5].message, "Row (u'%s ... (4703 characters truncated) ... %s',)" % (largeparam[0:148], largeparam[-149:]), ) @@ -495,7 +495,8 @@ class LoggingNameTest(fixtures.TestBase): __requires__ = ("ad_hoc_engines",) def _assert_names_in_execute(self, eng, eng_name, pool_name): - eng.execute(select(1)) + with eng.connect() as conn: + conn.execute(select(1)) assert self.buf.buffer for name in [b.name for b in self.buf.buffer]: assert name in ( @@ -505,7 +506,8 @@ class LoggingNameTest(fixtures.TestBase): ) def _assert_no_name_in_execute(self, eng): - eng.execute(select(1)) + with eng.connect() as conn: + conn.execute(select(1)) assert self.buf.buffer for name in [b.name for b in self.buf.buffer]: assert name in ( @@ -548,7 +550,8 @@ class LoggingNameTest(fixtures.TestBase): def test_named_logger_names_after_dispose(self): eng = self._named_engine() - eng.execute(select(1)) + with eng.connect() as conn: + conn.execute(select(1)) eng.dispose() eq_(eng.logging_name, "myenginename") eq_(eng.pool.logging_name, "mypoolname") @@ -568,7 +571,8 @@ class LoggingNameTest(fixtures.TestBase): def test_named_logger_execute_after_dispose(self): eng = self._named_engine() - eng.execute(select(1)) + with eng.connect() as conn: + conn.execute(select(1)) eng.dispose() self._assert_names_in_execute(eng, "myenginename", "mypoolname") @@ -599,7 +603,8 @@ class EchoTest(fixtures.TestBase): # do an initial execute to clear out 'first connect' # messages - e.execute(select(10)).close() + with e.connect() as conn: + conn.execute(select(10)).close() self.buf.flush() return e @@ -637,16 +642,25 @@ class EchoTest(fixtures.TestBase): e2 = self._testing_engine() e1.echo = True - e1.execute(select(1)).close() - e2.execute(select(2)).close() + + with e1.connect() as conn: + conn.execute(select(1)).close() + + with e2.connect() as conn: + conn.execute(select(2)).close() e1.echo = False - e1.execute(select(3)).close() - e2.execute(select(4)).close() + + with e1.connect() as conn: + conn.execute(select(3)).close() + with e2.connect() as conn: + conn.execute(select(4)).close() e2.echo = True - e1.execute(select(5)).close() - e2.execute(select(6)).close() + with e1.connect() as conn: + conn.execute(select(5)).close() + with e2.connect() as conn: + conn.execute(select(6)).close() assert self.buf.buffer[0].getMessage().startswith("SELECT 1") assert self.buf.buffer[2].getMessage().startswith("SELECT 6") diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 0dc35f99e8..ebdaa79a08 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1340,20 +1340,24 @@ class InvalidateDuringResultTest(fixtures.TestBase): def setup(self): self.engine = engines.reconnecting_engine() - self.meta = MetaData(self.engine) + self.meta = MetaData() table = Table( "sometable", self.meta, Column("id", Integer, primary_key=True), Column("name", String(50)), ) - self.meta.create_all() - table.insert().execute( - [{"id": i, "name": "row %d" % i} for i in range(1, 100)] - ) + + with self.engine.begin() as conn: + self.meta.create_all(conn) + conn.execute( + table.insert(), + [{"id": i, "name": "row %d" % i} for i in range(1, 100)], + ) def teardown(self): - self.meta.drop_all() + with self.engine.begin() as conn: + self.meta.drop_all(conn) self.engine.dispose() @testing.crashes( diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index b19836c842..48b6c40d77 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -2016,7 +2016,7 @@ def createIndexes(con, schema=None): @testing.requires.views def _create_views(con, schema=None): - with testing.db.connect() as conn: + with testing.db.begin() as conn: for table_name in ("users", "email_addresses"): fullname = table_name if schema: @@ -2031,7 +2031,7 @@ def _create_views(con, schema=None): @testing.requires.views def _drop_views(con, schema=None): - with testing.db.connect() as conn: + with testing.db.begin() as conn: for table_name in ("email_addresses", "users"): fullname = table_name if schema: @@ -2047,7 +2047,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): @testing.requires.denormalized_names def setup(self): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql( """ CREATE TABLE weird_casing( @@ -2060,7 +2060,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL): @testing.requires.denormalized_names def teardown(self): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.exec_driver_sql("drop table weird_casing") @testing.requires.denormalized_names diff --git a/test/engine/test_transaction.py b/test/engine/test_transaction.py index d0774e8464..4db5a745ad 100644 --- a/test/engine/test_transaction.py +++ b/test/engine/test_transaction.py @@ -5,20 +5,16 @@ from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import INT -from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import pool as _pool from sqlalchemy import select -from sqlalchemy import String from sqlalchemy import testing -from sqlalchemy import text from sqlalchemy import util from sqlalchemy import VARCHAR from sqlalchemy.engine import base from sqlalchemy.engine import characteristics from sqlalchemy.engine import default from sqlalchemy.engine import url -from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings @@ -29,31 +25,19 @@ from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -users, metadata = None, None - -class TransactionTest(fixtures.TestBase): +class TransactionTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - global users, metadata - metadata = MetaData() - users = Table( - "query_users", + def define_tables(cls, metadata): + Table( + "users", metadata, Column("user_id", INT, primary_key=True), Column("user_name", VARCHAR(20)), test_needs_acid=True, ) - users.create(testing.db) - - def teardown(self): - testing.db.execute(users.delete()).close() - - @classmethod - def teardown_class(cls): - users.drop(testing.db) @testing.fixture def local_connection(self): @@ -61,6 +45,7 @@ class TransactionTest(fixtures.TestBase): yield conn def test_commits(self, local_connection): + users = self.tables.users connection = local_connection transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") @@ -72,7 +57,7 @@ class TransactionTest(fixtures.TestBase): transaction.commit() transaction = connection.begin() - result = connection.exec_driver_sql("select * from query_users") + result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 3 transaction.commit() connection.close() @@ -80,17 +65,19 @@ class TransactionTest(fixtures.TestBase): def test_rollback(self, local_connection): """test a basic rollback""" + users = self.tables.users connection = local_connection transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") connection.execute(users.insert(), user_id=2, user_name="user2") connection.execute(users.insert(), user_id=3, user_name="user3") transaction.rollback() - result = connection.exec_driver_sql("select * from query_users") + result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 0 def test_raise(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() try: @@ -103,11 +90,12 @@ class TransactionTest(fixtures.TestBase): print("Exception: ", e) transaction.rollback() - result = connection.exec_driver_sql("select * from query_users") + result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 0 def test_nested_rollback(self, local_connection): connection = local_connection + users = self.tables.users try: transaction = connection.begin() try: @@ -146,6 +134,7 @@ class TransactionTest(fixtures.TestBase): def test_branch_nested_rollback(self, local_connection): connection = local_connection + users = self.tables.users connection.begin() branched = connection.connect() assert branched.in_transaction() @@ -179,6 +168,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.savepoints def test_savepoint_cancelled_by_toplevel_marker(self, local_connection): conn = local_connection + users = self.tables.users trans = conn.begin() conn.execute(users.insert(), {"user_id": 1, "user_name": "name"}) @@ -245,85 +235,6 @@ class TransactionTest(fixtures.TestBase): nested.commit, ) - def test_branch_autorollback(self, local_connection): - connection = local_connection - branched = connection.connect() - branched.execute(users.insert(), dict(user_id=1, user_name="user1")) - assert_raises( - exc.DBAPIError, - branched.execute, - users.insert(), - dict(user_id=1, user_name="user1"), - ) - # can continue w/o issue - branched.execute(users.insert(), dict(user_id=2, user_name="user2")) - - def test_branch_orig_rollback(self, local_connection): - connection = local_connection - branched = connection.connect() - branched.execute(users.insert(), dict(user_id=1, user_name="user1")) - nested = branched.begin() - assert branched.in_transaction() - branched.execute(users.insert(), dict(user_id=2, user_name="user2")) - nested.rollback() - eq_( - connection.exec_driver_sql( - "select count(*) from query_users" - ).scalar(), - 1, - ) - - @testing.requires.independent_connections - def test_branch_autocommit(self, local_connection): - with testing.db.connect() as connection: - branched = connection.connect() - branched.execute( - users.insert(), dict(user_id=1, user_name="user1") - ) - - eq_( - local_connection.execute( - text("select count(*) from query_users") - ).scalar(), - 1, - ) - - @testing.requires.savepoints - def test_branch_savepoint_rollback(self, local_connection): - connection = local_connection - trans = connection.begin() - branched = connection.connect() - assert branched.in_transaction() - branched.execute(users.insert(), user_id=1, user_name="user1") - nested = branched.begin_nested() - branched.execute(users.insert(), user_id=2, user_name="user2") - nested.rollback() - assert connection.in_transaction() - trans.commit() - eq_( - connection.exec_driver_sql( - "select count(*) from query_users" - ).scalar(), - 1, - ) - - @testing.requires.two_phase_transactions - def test_branch_twophase_rollback(self, local_connection): - connection = local_connection - branched = connection.connect() - assert not branched.in_transaction() - branched.execute(users.insert(), user_id=1, user_name="user1") - nested = branched.begin_twophase() - branched.execute(users.insert(), user_id=2, user_name="user2") - nested.rollback() - assert not connection.in_transaction() - eq_( - connection.exec_driver_sql( - "select count(*) from query_users" - ).scalar(), - 1, - ) - def test_deactivated_warning_ctxmanager(self, local_connection): with expect_warnings( "transaction already deassociated from connection" @@ -472,20 +383,20 @@ class TransactionTest(fixtures.TestBase): def test_retains_through_options(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") conn2 = connection.execution_options(dummy=True) conn2.execute(users.insert(), user_id=2, user_name="user2") transaction.rollback() eq_( - connection.exec_driver_sql( - "select count(*) from query_users" - ).scalar(), + connection.exec_driver_sql("select count(*) from users").scalar(), 0, ) def test_nesting(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") connection.execute(users.insert(), user_id=2, user_name="user2") @@ -497,15 +408,16 @@ class TransactionTest(fixtures.TestBase): transaction.rollback() self.assert_( connection.exec_driver_sql( - "select count(*) from " "query_users" + "select count(*) from " "users" ).scalar() == 0 ) - result = connection.exec_driver_sql("select * from query_users") + result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 0 def test_with_interface(self, local_connection): connection = local_connection + users = self.tables.users trans = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") connection.execute(users.insert(), user_id=2, user_name="user2") @@ -517,7 +429,7 @@ class TransactionTest(fixtures.TestBase): assert not trans.is_active self.assert_( connection.exec_driver_sql( - "select count(*) from " "query_users" + "select count(*) from " "users" ).scalar() == 0 ) @@ -528,13 +440,14 @@ class TransactionTest(fixtures.TestBase): assert not trans.is_active self.assert_( connection.exec_driver_sql( - "select count(*) from " "query_users" + "select count(*) from " "users" ).scalar() == 1 ) def test_close(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") connection.execute(users.insert(), user_id=2, user_name="user2") @@ -549,15 +462,16 @@ class TransactionTest(fixtures.TestBase): assert not connection.in_transaction() self.assert_( connection.exec_driver_sql( - "select count(*) from " "query_users" + "select count(*) from " "users" ).scalar() == 5 ) - result = connection.exec_driver_sql("select * from query_users") + result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 5 def test_close2(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") connection.execute(users.insert(), user_id=2, user_name="user2") @@ -572,16 +486,17 @@ class TransactionTest(fixtures.TestBase): assert not connection.in_transaction() self.assert_( connection.exec_driver_sql( - "select count(*) from " "query_users" + "select count(*) from " "users" ).scalar() == 0 ) - result = connection.exec_driver_sql("select * from query_users") + result = connection.exec_driver_sql("select * from users") assert len(result.fetchall()) == 0 @testing.requires.savepoints def test_nested_subtransaction_rollback(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() @@ -599,6 +514,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.savepoints def test_nested_subtransaction_commit(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() @@ -616,6 +532,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.savepoints def test_rollback_to_subtransaction(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin() connection.execute(users.insert(), user_id=1, user_name="user1") trans2 = connection.begin_nested() @@ -646,6 +563,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.two_phase_transactions def test_two_phase_transaction(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=1, user_name="user1") transaction.prepare() @@ -680,6 +598,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.savepoints def test_mixed_two_phase_transaction(self, local_connection): connection = local_connection + users = self.tables.users transaction = connection.begin_twophase() connection.execute(users.insert(), user_id=1, user_name="user1") transaction2 = connection.begin() @@ -704,6 +623,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.two_phase_transactions @testing.requires.two_phase_recovery def test_two_phase_recover(self): + users = self.tables.users # 2020, still can't get this to work w/ modern MySQL or MariaDB. # the XA RECOVER comes back as bytes, OK, convert to string, @@ -722,11 +642,14 @@ class TransactionTest(fixtures.TestBase): with testing.db.connect() as connection2: eq_( - connection2.execution_options(autocommit=True) - .execute(select(users.c.user_id).order_by(users.c.user_id)) - .fetchall(), + connection2.execute( + select(users.c.user_id).order_by(users.c.user_id) + ).fetchall(), [], ) + + # recover_twophase needs to be run in a new transaction + with testing.db.connect() as connection2: recoverables = connection2.recover_twophase() assert transaction.xid in recoverables connection2.commit_prepared(transaction.xid, recover=True) @@ -740,6 +663,7 @@ class TransactionTest(fixtures.TestBase): @testing.requires.two_phase_transactions def test_multiple_two_phase(self, local_connection): conn = local_connection + users = self.tables.users xa = conn.begin_twophase() conn.execute(users.insert(), user_id=1, user_name="user1") xa.prepare() @@ -767,6 +691,7 @@ class TransactionTest(fixtures.TestBase): # so that picky backends like MySQL correctly clear out # their state when a connection is closed without handling # the transaction explicitly. + users = self.tables.users eng = testing_engine() @@ -1005,7 +930,8 @@ class AutoRollbackTest(fixtures.TestBase): Column("user_name", VARCHAR(20)), test_needs_acid=True, ) - users.create(conn1) + with conn1.begin(): + users.create(conn1) conn1.exec_driver_sql("select * from deadlock_users") conn1.close() @@ -1014,125 +940,8 @@ class AutoRollbackTest(fixtures.TestBase): # pool but still has a lock on "deadlock_users". comment out the # rollback in pool/ConnectionFairy._close() to see ! - users.drop(conn2) - conn2.close() - - -class ExplicitAutoCommitTest(fixtures.TestBase): - - """test the 'autocommit' flag on select() and text() objects. - - Requires PostgreSQL so that we may define a custom function which - modifies the database.""" - - __only_on__ = "postgresql" - - @classmethod - def setup_class(cls): - global metadata, foo - metadata = MetaData(testing.db) - foo = Table( - "foo", - metadata, - Column("id", Integer, primary_key=True), - Column("data", String(100)), - ) - with testing.db.connect() as conn: - metadata.create_all(conn) - conn.exec_driver_sql( - "create function insert_foo(varchar) " - "returns integer as 'insert into foo(data) " - "values ($1);select 1;' language sql" - ) - - def teardown(self): - with testing.db.connect() as conn: - conn.execute(foo.delete()) - - @classmethod - def teardown_class(cls): - with testing.db.connect() as conn: - conn.exec_driver_sql("drop function insert_foo(varchar)") - metadata.drop_all(conn) - - def test_control(self): - - # test that not using autocommit does not commit - - conn1 = testing.db.connect() - conn2 = testing.db.connect() - conn1.execute(select(func.insert_foo("data1"))) - assert conn2.execute(select(foo.c.data)).fetchall() == [] - conn1.execute(text("select insert_foo('moredata')")) - assert conn2.execute(select(foo.c.data)).fetchall() == [] - trans = conn1.begin() - trans.commit() - assert conn2.execute(select(foo.c.data)).fetchall() == [ - ("data1",), - ("moredata",), - ] - conn1.close() - conn2.close() - - def test_explicit_compiled(self): - conn1 = testing.db.connect() - conn2 = testing.db.connect() - conn1.execute( - select(func.insert_foo("data1")).execution_options(autocommit=True) - ) - assert conn2.execute(select(foo.c.data)).fetchall() == [("data1",)] - conn1.close() - conn2.close() - - def test_explicit_connection(self): - conn1 = testing.db.connect() - conn2 = testing.db.connect() - conn1.execution_options(autocommit=True).execute( - select(func.insert_foo("data1")) - ) - eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)]) - - # connection supersedes statement - - conn1.execution_options(autocommit=False).execute( - select(func.insert_foo("data2")).execution_options(autocommit=True) - ) - eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)]) - - # ditto - - conn1.execution_options(autocommit=True).execute( - select(func.insert_foo("data3")).execution_options( - autocommit=False - ) - ) - eq_( - conn2.execute(select(foo.c.data)).fetchall(), - [("data1",), ("data2",), ("data3",)], - ) - conn1.close() - conn2.close() - - def test_explicit_text(self): - conn1 = testing.db.connect() - conn2 = testing.db.connect() - conn1.execute( - text("select insert_foo('moredata')").execution_options( - autocommit=True - ) - ) - assert conn2.execute(select(foo.c.data)).fetchall() == [("moredata",)] - conn1.close() - conn2.close() - - def test_implicit_text(self): - conn1 = testing.db.connect() - conn2 = testing.db.connect() - conn1.execute(text("insert into foo (data) values ('implicitdata')")) - assert conn2.execute(select(foo.c.data)).fetchall() == [ - ("implicitdata",) - ] - conn1.close() + with conn2.begin(): + users.drop(conn2) conn2.close() diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 3cb29c67dc..df27c8d270 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -1329,10 +1329,13 @@ class KVChild(object): self.value = value -class ReconstitutionTest(fixtures.TestBase): - def setup(self): - metadata = MetaData(testing.db) - parents = Table( +class ReconstitutionTest(fixtures.MappedTest): + run_setup_mappers = "each" + run_setup_classes = "each" + + @classmethod + def define_tables(cls, metadata): + Table( "parents", metadata, Column( @@ -1340,7 +1343,7 @@ class ReconstitutionTest(fixtures.TestBase): ), Column("name", String(30)), ) - children = Table( + Table( "children", metadata, Column( @@ -1349,22 +1352,23 @@ class ReconstitutionTest(fixtures.TestBase): Column("parent_id", Integer, ForeignKey("parents.id")), Column("name", String(30)), ) - metadata.create_all() - parents.insert().execute(name="p1") - self.metadata = metadata - self.parents = parents - self.children = children - Parent.kids = association_proxy("children", "name") - def teardown(self): - self.metadata.drop_all() - clear_mappers() + @classmethod + def insert_data(cls, connection): + parents = cls.tables.parents + connection.execute(parents.insert(), dict(name="p1")) + + @classmethod + def setup_classes(cls): + Parent.kids = association_proxy("children", "name") def test_weak_identity_map(self): mapper( - Parent, self.parents, properties=dict(children=relationship(Child)) + Parent, + self.tables.parents, + properties=dict(children=relationship(Child)), ) - mapper(Child, self.children) + mapper(Child, self.tables.children) session = create_session() def add_child(parent_name, child_name): @@ -1380,9 +1384,11 @@ class ReconstitutionTest(fixtures.TestBase): def test_copy(self): mapper( - Parent, self.parents, properties=dict(children=relationship(Child)) + Parent, + self.tables.parents, + properties=dict(children=relationship(Child)), ) - mapper(Child, self.children) + mapper(Child, self.tables.children) p = Parent("p1") p.kids.extend(["c1", "c2"]) p_copy = copy.copy(p) @@ -1392,9 +1398,11 @@ class ReconstitutionTest(fixtures.TestBase): def test_pickle_list(self): mapper( - Parent, self.parents, properties=dict(children=relationship(Child)) + Parent, + self.tables.parents, + properties=dict(children=relationship(Child)), ) - mapper(Child, self.children) + mapper(Child, self.tables.children) p = Parent("p1") p.kids.extend(["c1", "c2"]) r1 = pickle.loads(pickle.dumps(p)) @@ -1407,12 +1415,12 @@ class ReconstitutionTest(fixtures.TestBase): def test_pickle_set(self): mapper( Parent, - self.parents, + self.tables.parents, properties=dict( children=relationship(Child, collection_class=set) ), ) - mapper(Child, self.children) + mapper(Child, self.tables.children) p = Parent("p1") p.kids.update(["c1", "c2"]) r1 = pickle.loads(pickle.dumps(p)) @@ -1425,7 +1433,7 @@ class ReconstitutionTest(fixtures.TestBase): def test_pickle_dict(self): mapper( Parent, - self.parents, + self.tables.parents, properties=dict( children=relationship( KVChild, @@ -1435,7 +1443,7 @@ class ReconstitutionTest(fixtures.TestBase): ) ), ) - mapper(KVChild, self.children) + mapper(KVChild, self.tables.children) p = Parent("p1") p.kids.update({"c1": "v1", "c2": "v2"}) assert p.kids == {"c1": "c1", "c2": "c2"} diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index a8c17d7aca..e46c65ff02 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -53,10 +53,10 @@ class ShardTest(object): def id_generator(ctx): # in reality, might want to use a separate transaction for this. - c = db1.connect() - nextid = c.execute(ids.select().with_for_update()).scalar() - c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1})) - return nextid + with db1.begin() as c: + nextid = c.execute(ids.select().with_for_update()).scalar() + c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1})) + return nextid weather_locations = Table( "weather_locations", @@ -80,7 +80,8 @@ class ShardTest(object): for db in (db1, db2, db3, db4): meta.create_all(db) - db1.execute(ids.insert(), nextid=1) + with db1.begin() as conn: + conn.execute(ids.insert(), dict(nextid=1)) self.setup_session() self.setup_mappers() @@ -762,7 +763,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): ) e2 = testing_engine() - with e2.connect() as conn: + with e2.begin() as conn: for i in [2, 4]: conn.exec_driver_sql( "CREATE SCHEMA IF NOT EXISTS shard%s" % (i,) @@ -784,7 +785,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase): for i in [1, 3]: os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT)) - with self.postgresql_engine.connect() as conn: + with self.postgresql_engine.begin() as conn: self.metadata.drop_all(conn) for i in [2, 4]: conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,)) diff --git a/test/orm/inheritance/test_selects.py b/test/orm/inheritance/test_selects.py index c9a78db081..dab1841943 100644 --- a/test/orm/inheritance/test_selects.py +++ b/test/orm/inheritance/test_selects.py @@ -2,7 +2,6 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String -from sqlalchemy import testing from sqlalchemy.orm import mapper from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ @@ -24,13 +23,13 @@ class InheritingSelectablesTest(fixtures.MappedTest): cls.tables.bar = foo.select(foo.c.b == "bar").alias("bar") cls.tables.baz = foo.select(foo.c.b == "baz").alias("baz") - def test_load(self): + def test_load(self, connection): foo, bar, baz = self.tables.foo, self.tables.bar, self.tables.baz # TODO: add persistence test also - testing.db.execute(foo.insert(), a="not bar", b="baz") - testing.db.execute(foo.insert(), a="also not bar", b="baz") - testing.db.execute(foo.insert(), a="i am bar", b="bar") - testing.db.execute(foo.insert(), a="also bar", b="bar") + connection.execute(foo.insert(), dict(a="not bar", b="baz")) + connection.execute(foo.insert(), dict(a="also not bar", b="baz")) + connection.execute(foo.insert(), dict(a="i am bar", b="bar")) + connection.execute(foo.insert(), dict(a="also bar", b="bar")) class Foo(fixtures.ComparableEntity): pass @@ -69,8 +68,8 @@ class InheritingSelectablesTest(fixtures.MappedTest): polymorphic_identity="bar", ) - s = Session() - assert [Bar(), Bar()] == s.query(Bar).all() + s = Session(connection) + eq_(s.query(Bar).all(), [Bar(), Bar()]) class JoinFromSelectPersistenceTest(fixtures.MappedTest): diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py index 3a99598570..64f85b3351 100644 --- a/test/orm/test_bind.py +++ b/test/orm/test_bind.py @@ -151,7 +151,7 @@ class BindIntegrationTest(_fixtures.FixtureTest): mapper(User, users) - session = create_session() + session = Session() session.execute(users.insert(), dict(name="Johnny")) @@ -447,7 +447,9 @@ class BindIntegrationTest(_fixtures.FixtureTest): sess.commit() assert not c.in_transaction() assert c.exec_driver_sql("select count(1) from users").scalar() == 1 - c.exec_driver_sql("delete from users") + + with c.begin(): + c.exec_driver_sql("delete from users") assert c.exec_driver_sql("select count(1) from users").scalar() == 0 c = testing.db.connect() diff --git a/test/orm/test_compile.py b/test/orm/test_compile.py index dcf07eec8a..c6a1226d4b 100644 --- a/test/orm/test_compile.py +++ b/test/orm/test_compile.py @@ -190,8 +190,9 @@ class CompileTest(fixtures.ORMTest): sa_exc.ArgumentError, "Error creating backref", configure_mappers ) - def test_misc_one(self): - metadata = MetaData(testing.db) + @testing.provide_metadata + def test_misc_one(self, connection): + metadata = self.metadata node_table = Table( "node", metadata, @@ -212,33 +213,30 @@ class CompileTest(fixtures.ORMTest): Column("host_id", Integer, primary_key=True), Column("hostname", String(64), nullable=False, unique=True), ) - metadata.create_all() - try: - node_table.insert().execute(node_id=1, node_index=5) - - class Node(object): - pass - - class NodeName(object): - pass - - class Host(object): - pass - - mapper(Node, node_table) - mapper(Host, host_table) - mapper( - NodeName, - node_name_table, - properties={ - "node": relationship(Node, backref=backref("names")), - "host": relationship(Host), - }, - ) - sess = create_session() - assert sess.query(Node).get(1).names == [] - finally: - metadata.drop_all() + metadata.create_all(connection) + connection.execute(node_table.insert(), dict(node_id=1, node_index=5)) + + class Node(object): + pass + + class NodeName(object): + pass + + class Host(object): + pass + + mapper(Node, node_table) + mapper(Host, host_table) + mapper( + NodeName, + node_name_table, + properties={ + "node": relationship(Node, backref=backref("names")), + "host": relationship(Host), + }, + ) + sess = create_session(connection) + assert sess.query(Node).get(1).names == [] def test_conflicting_backref_two(self): meta = MetaData() diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 57225d6406..7bc82b2a3a 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -4808,6 +4808,8 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL): class SubqueryTest(fixtures.MappedTest): + run_deletes = "each" + @classmethod def define_tables(cls, metadata): Table( @@ -4830,7 +4832,12 @@ class SubqueryTest(fixtures.MappedTest): Column("score2", sa.Float), ) - def test_label_anonymizing(self): + @testing.combinations( + (True, "score"), + (True, None), + (False, None), + ) + def test_label_anonymizing(self, labeled, labelname): """Eager loading works with subqueries with labels, Even if an explicit labelname which conflicts with a label on the @@ -4859,75 +4866,65 @@ class SubqueryTest(fixtures.MappedTest): def prop_score(self): return self.score1 * self.score2 - for labeled, labelname in [ - (True, "score"), - (True, None), - (False, None), - ]: - sa.orm.clear_mappers() - - tag_score = tags_table.c.score1 * tags_table.c.score2 - user_score = sa.select( - sa.func.sum(tags_table.c.score1 * tags_table.c.score2) - ).where( - tags_table.c.user_id == users_table.c.id, - ) + tag_score = tags_table.c.score1 * tags_table.c.score2 + user_score = sa.select( + sa.func.sum(tags_table.c.score1 * tags_table.c.score2) + ).where( + tags_table.c.user_id == users_table.c.id, + ) - if labeled: - tag_score = tag_score.label(labelname) - user_score = user_score.label(labelname) - else: - user_score = user_score.scalar_subquery() + if labeled: + tag_score = tag_score.label(labelname) + user_score = user_score.label(labelname) + else: + user_score = user_score.scalar_subquery() - mapper( - Tag, - tags_table, - properties={"query_score": sa.orm.column_property(tag_score)}, - ) + mapper( + Tag, + tags_table, + properties={"query_score": sa.orm.column_property(tag_score)}, + ) - mapper( - User, - users_table, - properties={ - "tags": relationship(Tag, backref="user", lazy="joined"), - "query_score": sa.orm.column_property(user_score), - }, - ) + mapper( + User, + users_table, + properties={ + "tags": relationship(Tag, backref="user", lazy="joined"), + "query_score": sa.orm.column_property(user_score), + }, + ) - session = create_session() - session.add( - User( - name="joe", - tags=[ - Tag(score1=5.0, score2=3.0), - Tag(score1=55.0, score2=1.0), - ], - ) + session = create_session() + session.add( + User( + name="joe", + tags=[ + Tag(score1=5.0, score2=3.0), + Tag(score1=55.0, score2=1.0), + ], ) - session.add( - User( - name="bar", - tags=[ - Tag(score1=5.0, score2=4.0), - Tag(score1=50.0, score2=1.0), - Tag(score1=15.0, score2=2.0), - ], - ) + ) + session.add( + User( + name="bar", + tags=[ + Tag(score1=5.0, score2=4.0), + Tag(score1=50.0, score2=1.0), + Tag(score1=15.0, score2=2.0), + ], ) - session.flush() - session.expunge_all() - - for user in session.query(User).all(): - eq_(user.query_score, user.prop_score) + ) + session.flush() + session.expunge_all() - def go(): - u = session.query(User).filter_by(name="joe").one() - eq_(u.query_score, u.prop_score) + for user in session.query(User).all(): + eq_(user.query_score, user.prop_score) - self.assert_sql_count(testing.db, go, 1) + def go(): + u = session.query(User).filter_by(name="joe").one() + eq_(u.query_score, u.prop_score) - for t in (tags_table, users_table): - t.delete().execute() + self.assert_sql_count(testing.db, go, 1) class CorrelatedSubqueryTest(fixtures.MappedTest): diff --git a/test/orm/test_expire.py b/test/orm/test_expire.py index 7ccf2c1aee..5abaa03db5 100644 --- a/test/orm/test_expire.py +++ b/test/orm/test_expire.py @@ -9,7 +9,6 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.orm import attributes -from sqlalchemy.orm import create_session from sqlalchemy.orm import defer from sqlalchemy.orm import deferred from sqlalchemy.orm import exc as orm_exc @@ -26,6 +25,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import create_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -66,7 +66,7 @@ class ExpireTest(_fixtures.FixtureTest): u.name = "foo" sess.flush() # change the value in the DB - users.update(users.c.id == 7, values=dict(name="jack")).execute() + sess.execute(users.update(users.c.id == 7, values=dict(name="jack"))) sess.expire(u) # object isn't refreshed yet, using dict to bypass trigger assert u.__dict__.get("name") != "jack" @@ -471,7 +471,7 @@ class ExpireTest(_fixtures.FixtureTest): o = sess.query(Order).get(3) sess.expire(o) - orders.update().execute(description="order 3 modified") + sess.execute(orders.update(), dict(description="order 3 modified")) assert o.isopen == 1 assert ( attributes.instance_state(o).dict["description"] @@ -788,7 +788,7 @@ class ExpireTest(_fixtures.FixtureTest): sess.expire(u) assert "name" not in u.__dict__ - users.update(users.c.id == 7).execute(name="jack2") + sess.execute(users.update(users.c.id == 7), dict(name="jack2")) assert u.name == "jack2" assert u.uname == "jack2" assert "name" in u.__dict__ @@ -812,7 +812,10 @@ class ExpireTest(_fixtures.FixtureTest): assert "description" not in o.__dict__ assert attributes.instance_state(o).dict["isopen"] == 1 - orders.update(orders.c.id == 3).execute(description="order 3 modified") + sess.execute( + orders.update(orders.c.id == 3), + dict(description="order 3 modified"), + ) def go(): assert o.description == "order 3 modified" @@ -1660,12 +1663,9 @@ class LifecycleTest(fixtures.MappedTest): def test_cols_missing_in_load(self): Data = self.classes.Data - sess = create_session() - - d1 = Data(data="d1") - sess.add(d1) - sess.flush() - sess.close() + with Session(testing.db) as sess, sess.begin(): + d1 = Data(data="d1") + sess.add(d1) sess = create_session() d1 = sess.query(Data).from_statement(select(Data.id)).first() @@ -1679,21 +1679,18 @@ class LifecycleTest(fixtures.MappedTest): def test_deferred_cols_missing_in_load_state_reset(self): Data = self.classes.DataDefer - sess = create_session() + with Session(testing.db) as sess, sess.begin(): + d1 = Data(data="d1") + sess.add(d1) - d1 = Data(data="d1") - sess.add(d1) - sess.flush() - sess.close() - - sess = create_session() - d1 = ( - sess.query(Data) - .from_statement(select(Data.id)) - .options(undefer(Data.data)) - .first() - ) - d1.data = "d2" + with Session(testing.db) as sess: + d1 = ( + sess.query(Data) + .from_statement(select(Data.id)) + .options(undefer(Data.data)) + .first() + ) + d1.data = "d2" # the deferred loader has to clear out any state # on the col, including that 'd2' here diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index c1cc85261f..e1c0ec77b8 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -1302,18 +1302,22 @@ class O2MWOSideFixedTest(fixtures.MappedTest): def _fixture(self, include_other): city, person = self.tables.city, self.tables.person - if include_other: - city.insert().execute({"id": 1, "deleted": False}) - - person.insert().execute( - {"id": 1, "city_id": 1}, {"id": 2, "city_id": 1} - ) + with testing.db.begin() as conn: + if include_other: + conn.execute(city.insert(), {"id": 1, "deleted": False}) + + conn.execute( + person.insert(), + {"id": 1, "city_id": 1}, + {"id": 2, "city_id": 1}, + ) - city.insert().execute({"id": 2, "deleted": True}) + conn.execute(city.insert(), {"id": 2, "deleted": True}) - person.insert().execute( - {"id": 3, "city_id": 2}, {"id": 4, "city_id": 2} - ) + conn.execute( + person.insert(), + [{"id": 3, "city_id": 2}, {"id": 4, "city_id": 2}], + ) def test_lazyload_assert_expected_sql(self): self._fixture(True) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index fc6caa75d4..edbb4b0cd0 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -129,7 +129,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): ) assert_raises(sa.exc.ArgumentError, sa.orm.configure_mappers) - def test_update_attr_keys(self): + def test_update_attr_keys(self, connection): """test that update()/insert() use the correct key when given InstrumentedAttributes.""" @@ -137,21 +137,21 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): self.mapper(User, users, properties={"foobar": users.c.name}) - users.insert().values({User.foobar: "name1"}).execute() + connection.execute(users.insert().values({User.foobar: "name1"})) eq_( - sa.select(User.foobar) - .where(User.foobar == "name1") - .execute() - .fetchall(), + connection.execute( + sa.select(User.foobar).where(User.foobar == "name1") + ).fetchall(), [("name1",)], ) - users.update().values({User.foobar: User.foobar + "foo"}).execute() + connection.execute( + users.update().values({User.foobar: User.foobar + "foo"}) + ) eq_( - sa.select(User.foobar) - .where(User.foobar == "name1foo") - .execute() - .fetchall(), + connection.execute( + sa.select(User.foobar).where(User.foobar == "name1foo") + ).fetchall(), [("name1foo",)], ) diff --git a/test/orm/test_naturalpks.py b/test/orm/test_naturalpks.py index 87ec0d79d3..d814b0cab8 100644 --- a/test/orm/test_naturalpks.py +++ b/test/orm/test_naturalpks.py @@ -12,7 +12,6 @@ from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import TypeDecorator -from sqlalchemy.orm import create_session from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship from sqlalchemy.orm import Session @@ -23,6 +22,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import ne_ +from sqlalchemy.testing.fixtures import create_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from test.orm import _fixtures @@ -141,7 +141,9 @@ class NaturalPKTest(fixtures.MappedTest): sess.flush() assert sess.query(User).get("jack") is u1 - users.update(values={User.username: "jack"}).execute(username="ed") + sess.execute( + users.update(values={User.username: "jack"}), dict(username="ed") + ) # expire/refresh works off of primary key. the PK is gone # in this case so there's no way to look it up. criterion- @@ -1089,7 +1091,7 @@ class NonPKCascadeTest(fixtures.MappedTest): a1 = u1.addresses[0] eq_( - sa.select(addresses.c.username).execute().fetchall(), + sess.execute(sa.select(addresses.c.username)).fetchall(), [("jack",), ("jack",)], ) @@ -1099,7 +1101,7 @@ class NonPKCascadeTest(fixtures.MappedTest): sess.flush() assert u1.addresses[0].username == "ed" eq_( - sa.select(addresses.c.username).execute().fetchall(), + sess.execute(sa.select(addresses.c.username)).fetchall(), [("ed",), ("ed",)], ) @@ -1141,7 +1143,7 @@ class NonPKCascadeTest(fixtures.MappedTest): eq_(a1.username, None) eq_( - sa.select(addresses.c.username).execute().fetchall(), + sess.execute(sa.select(addresses.c.username)).fetchall(), [(None,), (None,)], ) @@ -1454,7 +1456,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(a1.username, "ed") eq_(a2.username, "ed") eq_( - sa.select(addresses.c.username).execute().fetchall(), + sess.execute(sa.select(addresses.c.username)).fetchall(), [("ed",), ("ed",)], ) @@ -1465,7 +1467,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL): eq_(a1.username, "jack") eq_(a2.username, "jack") eq_( - sa.select(addresses.c.username).execute().fetchall(), + sess.execute(sa.select(addresses.c.username)).fetchall(), [("jack",), ("jack",)], ) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 8cca45b270..9e528dc0d4 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -806,7 +806,7 @@ class GetTest(QueryTest): @testing.provide_metadata @testing.requires.unicode_connections - def test_unicode(self): + def test_unicode(self, connection): """test that Query.get properly sets up the type for the bind parameter. using unicode would normally fail on postgresql, mysql and oracle unless it is converted to an encoded string""" @@ -818,19 +818,20 @@ class GetTest(QueryTest): Column("id", Unicode(40), primary_key=True), Column("data", Unicode(40)), ) - metadata.create_all() + metadata.create_all(connection) ustring = util.b("petit voix m\xe2\x80\x99a").decode("utf-8") - table.insert().execute(id=ustring, data=ustring) + connection.execute(table.insert(), dict(id=ustring, data=ustring)) class LocalFoo(self.classes.Base): pass mapper(LocalFoo, table) - eq_( - create_session().query(LocalFoo).get(ustring), - LocalFoo(id=ustring, data=ustring), - ) + with Session(connection) as sess: + eq_( + sess.get(LocalFoo, ustring), + LocalFoo(id=ustring, data=ustring), + ) def test_populate_existing(self): User, Address = self.classes.User, self.classes.Address diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 1650082346..d2838e5bf8 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -12,7 +12,6 @@ from sqlalchemy import testing from sqlalchemy.orm import attributes from sqlalchemy.orm import backref from sqlalchemy.orm import close_all_sessions -from sqlalchemy.orm import create_session from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import joinedload from sqlalchemy.orm import make_transient @@ -35,6 +34,7 @@ from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import pickleable +from sqlalchemy.testing.fixtures import create_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect @@ -48,33 +48,33 @@ class ExecutionTest(_fixtures.FixtureTest): __backend__ = True @testing.requires.sequences - def test_sequence_execute(self): + def test_sequence_execute(self, connection): seq = Sequence("some_sequence") - seq.create(testing.db) + seq.create(connection) try: - sess = create_session(bind=testing.db) - eq_(sess.execute(seq), testing.db.dialect.default_sequence_base) + sess = Session(connection) + eq_(sess.execute(seq), connection.dialect.default_sequence_base) finally: - seq.drop(testing.db) + seq.drop(connection) - def test_textual_execute(self): + def test_textual_execute(self, connection): """test that Session.execute() converts to text()""" users = self.tables.users - sess = create_session(bind=self.metadata.bind) - users.insert().execute(id=7, name="jack") + with Session(bind=connection) as sess: + sess.execute(users.insert(), dict(id=7, name="jack")) - # use :bindparam style - eq_( - sess.execute( - "select * from users where id=:id", {"id": 7} - ).fetchall(), - [(7, "jack")], - ) + # use :bindparam style + eq_( + sess.execute( + "select * from users where id=:id", {"id": 7} + ).fetchall(), + [(7, "jack")], + ) - # use :bindparam style - eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7) + # use :bindparam style + eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7) def test_parameter_execute(self): users = self.tables.users @@ -104,7 +104,7 @@ class TransScopingTest(_fixtures.FixtureTest): c.exec_driver_sql("select * from users") mapper(User, users) - s = create_session(bind=c) + s = Session(bind=c) s.add(User(name="first")) s.flush() c.exec_driver_sql("select * from users") @@ -118,7 +118,7 @@ class TransScopingTest(_fixtures.FixtureTest): c.exec_driver_sql("select * from users") mapper(User, users) - s = create_session(bind=c) + s = Session(bind=c) s.add(User(name="first")) s.flush() c.exec_driver_sql("select * from users") @@ -189,7 +189,7 @@ class TransScopingTest(_fixtures.FixtureTest): conn1 = testing.db.connect() conn2 = testing.db.connect() - sess = create_session(autocommit=False, bind=conn1) + sess = Session(autocommit=False, bind=conn1) u = User(name="x") sess.add(u) sess.flush() @@ -415,7 +415,7 @@ class SessionStateTest(_fixtures.FixtureTest): conn1 = bind.connect() conn2 = bind.connect() - sess = create_session(bind=conn1, autocommit=False, autoflush=True) + sess = Session(bind=conn1, autocommit=False, autoflush=True) u = User() u.name = "ed" sess.add(u) @@ -600,7 +600,7 @@ class SessionStateTest(_fixtures.FixtureTest): mapper(User, users) conn1 = testing.db.connect() - sess = create_session(bind=conn1, autocommit=False, autoflush=True) + sess = Session(bind=conn1, autocommit=False, autoflush=True) u = User() u.name = "ed" sess.add(u) @@ -620,7 +620,7 @@ class SessionStateTest(_fixtures.FixtureTest): User, users = self.classes.User, self.tables.users mapper(User, users) - session = create_session(autocommit=True) + session = Session(testing.db, autocommit=True) session.add(User(name="ed")) @@ -629,7 +629,7 @@ class SessionStateTest(_fixtures.FixtureTest): session.commit() def test_active_flag_autocommit(self): - sess = create_session(bind=config.db, autocommit=True) + sess = Session(bind=config.db, autocommit=True) assert not sess.is_active sess.begin() assert sess.is_active @@ -637,7 +637,7 @@ class SessionStateTest(_fixtures.FixtureTest): assert not sess.is_active def test_active_flag_autobegin(self): - sess = create_session(bind=config.db, autocommit=False) + sess = Session(bind=config.db, autocommit=False) assert sess.is_active assert not sess.in_transaction() sess.begin() @@ -646,7 +646,7 @@ class SessionStateTest(_fixtures.FixtureTest): assert sess.is_active def test_active_flag_autobegin_future(self): - sess = create_session(bind=config.db, future=True) + sess = Session(bind=config.db, future=True) assert sess.is_active assert not sess.in_transaction() sess.begin() @@ -655,7 +655,7 @@ class SessionStateTest(_fixtures.FixtureTest): assert sess.is_active def test_active_flag_partial_rollback(self): - sess = create_session(bind=config.db, autocommit=False) + sess = Session(bind=config.db, autocommit=False) assert sess.is_active assert not sess.in_transaction() sess.begin() @@ -693,7 +693,7 @@ class SessionStateTest(_fixtures.FixtureTest): ) s.add(user) - s.flush() + s.commit() user = s.query(User).one() s.expunge(user) assert user not in s @@ -703,8 +703,7 @@ class SessionStateTest(_fixtures.FixtureTest): s.add(user) assert user in s assert user in s.dirty - s.flush() - s.expunge_all() + s.commit() assert s.query(User).count() == 1 user = s.query(User).one() assert user.name == "fred" @@ -766,8 +765,9 @@ class SessionStateTest(_fixtures.FixtureTest): users, User = self.tables.users, self.classes.User mapper(User, users) - for s in (create_session(), create_session()): - users.delete().execute() + + with create_session() as s: + s.execute(users.delete()) u1 = User(name="ed") s.add(u1) s.flush() @@ -1774,7 +1774,8 @@ class DisposedStates(fixtures.MappedTest): def _test_session(self, **kwargs): T = self.classes.T - sess = create_session(**kwargs) + + sess = Session(config.db, **kwargs) data = o1, o2, o3, o4, o5 = [ T("t1"), @@ -1786,7 +1787,7 @@ class DisposedStates(fixtures.MappedTest): sess.add_all(data) - sess.flush() + sess.commit() o1.data = "t1modified" o5.data = "t5modified" @@ -1925,7 +1926,7 @@ class SessionInterface(fixtures.TestBase): def raises_(method, *args, **kw): watchdog.add(method) - callable_ = getattr(create_session(), method) + callable_ = getattr(Session(), method) if is_class: assert_raises( sa.orm.exc.UnmappedClassError, callable_, *args, **kw diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index e8f6c5c405..248f334cf6 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -1951,9 +1951,7 @@ class AccountingFlagsTest(_LocalFixture): sess.add(u1) sess.commit() - testing.db.execute( - users.update(users.c.name == "ed").values(name="edward") - ) + sess.execute(users.update(users.c.name == "ed").values(name="edward")) assert u1.name == "ed" sess.expire_all() diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index ed320db104..31386b07f5 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -778,7 +778,8 @@ class SingleCycleTest(UOWTest): # mysql can't handle delete from nodes # since it doesn't deal with the FKs correctly, # so wipe out the parent_id first - testing.db.execute(self.tables.nodes.update().values(parent_id=None)) + with testing.db.begin() as conn: + conn.execute(self.tables.nodes.update().values(parent_id=None)) super(SingleCycleTest, self).teardown() def test_one_to_many_save(self): diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 4a6ebd0c83..2a2e70bc39 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1012,9 +1012,7 @@ class PKIncrementTest(fixtures.TablesTest): Column("str1", String(20)), ) - # TODO: add coverage for increment on a secondary column in a key - @testing.fails_on("firebird", "Data type unknown") - def _test_autoincrement(self, connection): + def test_autoincrement(self, connection): aitable = self.tables.aitable ids = set() @@ -1064,14 +1062,6 @@ class PKIncrementTest(fixtures.TablesTest): ], ) - def test_autoincrement_autocommit(self): - with testing.db.connect() as conn: - self._test_autoincrement(conn) - - def test_autoincrement_transaction(self): - with testing.db.begin() as conn: - self._test_autoincrement(conn) - class EmptyInsertTest(fixtures.TestBase): __backend__ = True @@ -1267,7 +1257,7 @@ class SpecialTypePKTest(fixtures.TestBase): implicit_returning=implicit_returning, ) - with testing.db.connect() as conn: + with testing.db.begin() as conn: t.create(conn) r = conn.execute(t.insert().values(data=5)) diff --git a/test/sql/test_delete.py b/test/sql/test_delete.py index 934022560f..6f7b3f8f5d 100644 --- a/test/sql/test_delete.py +++ b/test/sql/test_delete.py @@ -308,32 +308,31 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): ) @testing.requires.delete_from - def test_exec_two_table(self): + def test_exec_two_table(self, connection): users, addresses = self.tables.users, self.tables.addresses dingalings = self.tables.dingalings - with testing.db.connect() as conn: - conn.execute(dingalings.delete()) # fk violation otherwise + connection.execute(dingalings.delete()) # fk violation otherwise - conn.execute( - addresses.delete() - .where(users.c.id == addresses.c.user_id) - .where(users.c.name == "ed") - ) + connection.execute( + addresses.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + ) - expected = [ - (1, 7, "x", "jack@bean.com"), - (5, 9, "x", "fred@fred.com"), - ] - self._assert_table(addresses, expected) + expected = [ + (1, 7, "x", "jack@bean.com"), + (5, 9, "x", "fred@fred.com"), + ] + self._assert_table(connection, addresses, expected) @testing.requires.delete_from - def test_exec_three_table(self): + def test_exec_three_table(self, connection): users = self.tables.users addresses = self.tables.addresses dingalings = self.tables.dingalings - testing.db.execute( + connection.execute( dingalings.delete() .where(users.c.id == addresses.c.user_id) .where(users.c.name == "ed") @@ -341,34 +340,33 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): ) expected = [(2, 5, "ding 2/5")] - self._assert_table(dingalings, expected) + self._assert_table(connection, dingalings, expected) @testing.requires.delete_from - def test_exec_two_table_plus_alias(self): + def test_exec_two_table_plus_alias(self, connection): users, addresses = self.tables.users, self.tables.addresses dingalings = self.tables.dingalings - with testing.db.connect() as conn: - conn.execute(dingalings.delete()) # fk violation otherwise - a1 = addresses.alias() - conn.execute( - addresses.delete() - .where(users.c.id == addresses.c.user_id) - .where(users.c.name == "ed") - .where(a1.c.id == addresses.c.id) - ) + connection.execute(dingalings.delete()) # fk violation otherwise + a1 = addresses.alias() + connection.execute( + addresses.delete() + .where(users.c.id == addresses.c.user_id) + .where(users.c.name == "ed") + .where(a1.c.id == addresses.c.id) + ) expected = [(1, 7, "x", "jack@bean.com"), (5, 9, "x", "fred@fred.com")] - self._assert_table(addresses, expected) + self._assert_table(connection, addresses, expected) @testing.requires.delete_from - def test_exec_alias_plus_table(self): + def test_exec_alias_plus_table(self, connection): users, addresses = self.tables.users, self.tables.addresses dingalings = self.tables.dingalings d1 = dingalings.alias() - testing.db.execute( + connection.execute( delete(d1) .where(users.c.id == addresses.c.user_id) .where(users.c.name == "ed") @@ -376,8 +374,8 @@ class DeleteFromRoundTripTest(fixtures.TablesTest): ) expected = [(2, 5, "ding 2/5")] - self._assert_table(dingalings, expected) + self._assert_table(connection, dingalings, expected) - def _assert_table(self, table, expected): + def _assert_table(self, connection, table, expected): stmt = table.select().order_by(table.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) diff --git a/test/sql/test_deprecations.py b/test/sql/test_deprecations.py index c0d2e87e86..e082cf55d0 100644 --- a/test/sql/test_deprecations.py +++ b/test/sql/test_deprecations.py @@ -23,6 +23,7 @@ from sqlalchemy import MetaData from sqlalchemy import null from sqlalchemy import or_ from sqlalchemy import select +from sqlalchemy import Sequence from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import table @@ -1271,6 +1272,165 @@ class KeyTargetingTest(fixtures.TablesTest): in_(stmt.c.keyed2_b, row) +class PKIncrementTest(fixtures.TablesTest): + run_define_tables = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "aitable", + metadata, + Column( + "id", + Integer, + Sequence("ai_id_seq", optional=True), + primary_key=True, + ), + Column("int1", Integer), + Column("str1", String(20)), + ) + + def _test_autoincrement(self, connection): + aitable = self.tables.aitable + + ids = set() + rs = connection.execute(aitable.insert(), int1=1) + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = connection.execute(aitable.insert(), str1="row 2") + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = connection.execute(aitable.insert(), int1=3, str1="row 3") + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + rs = connection.execute( + aitable.insert().values({"int1": func.length("four")}) + ) + last = rs.inserted_primary_key[0] + self.assert_(last) + self.assert_(last not in ids) + ids.add(last) + + eq_( + ids, + set( + range( + testing.db.dialect.default_sequence_base, + testing.db.dialect.default_sequence_base + 4, + ) + ), + ) + + eq_( + list(connection.execute(aitable.select().order_by(aitable.c.id))), + [ + (testing.db.dialect.default_sequence_base, 1, None), + (testing.db.dialect.default_sequence_base + 1, None, "row 2"), + (testing.db.dialect.default_sequence_base + 2, 3, "row 3"), + (testing.db.dialect.default_sequence_base + 3, 4, None), + ], + ) + + def test_autoincrement_autocommit(self): + with testing.db.connect() as conn: + with testing.expect_deprecated_20( + "The current statement is being autocommitted using " + "implicit autocommit, " + ): + self._test_autoincrement(conn) + + +class ConnectionlessCursorResultTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column( + "user_id", INT, primary_key=True, test_needs_autoincrement=True + ), + Column("user_name", VARCHAR(20)), + test_needs_acid=True, + ) + + def test_connectionless_autoclose_rows_exhausted(self): + users = self.tables.users + with testing.db.begin() as conn: + conn.execute(users.insert(), dict(user_id=1, user_name="john")) + + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(text("select * from users")) + connection = result.connection + assert not connection.closed + eq_(result.fetchone(), (1, "john")) + assert not connection.closed + eq_(result.fetchone(), None) + assert connection.closed + + @testing.requires.returning + def test_connectionless_autoclose_crud_rows_exhausted(self): + users = self.tables.users + stmt = ( + users.insert() + .values(user_id=1, user_name="john") + .returning(users.c.user_id) + ) + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(stmt) + connection = result.connection + assert not connection.closed + eq_(result.fetchone(), (1,)) + assert not connection.closed + eq_(result.fetchone(), None) + assert connection.closed + + def test_connectionless_autoclose_no_rows(self): + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(text("select * from users")) + connection = result.connection + assert not connection.closed + eq_(result.fetchone(), None) + assert connection.closed + + @testing.requires.updateable_autoincrement_pks + def test_connectionless_autoclose_no_metadata(self): + with testing.expect_deprecated_20( + r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method" + ): + result = testing.db.execute(text("update users set user_id=5")) + connection = result.connection + assert connection.closed + + assert_raises_message( + exc.ResourceClosedError, + "This result object does not return rows.", + result.fetchone, + ) + assert_raises_message( + exc.ResourceClosedError, + "This result object does not return rows.", + result.keys, + ) + + class CursorResultTest(fixtures.TablesTest): __backend__ = True @@ -1436,7 +1596,7 @@ class CursorResultTest(fixtures.TablesTest): def test_pickled_rows(self): users = self.tables.users addresses = self.tables.addresses - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(users.delete()) conn.execute( users.insert(), @@ -2319,3 +2479,93 @@ class LegacyOperatorTest(AssertsCompiledSQL, fixtures.TestBase): _op_modern = getattr(operators.ColumnOperators, _modern) _op_legacy = getattr(operators.ColumnOperators, _legacy) assert _op_modern == _op_legacy + + +class LegacySequenceExecTest(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + @classmethod + def setup_class(cls): + cls.seq = Sequence("my_sequence") + cls.seq.create(testing.db) + + @classmethod + def teardown_class(cls): + cls.seq.drop(testing.db) + + def _assert_seq_result(self, ret): + """asserts return of next_value is an int""" + + assert isinstance(ret, util.int_types) + assert ret >= testing.db.dialect.default_sequence_base + + def test_implicit_connectionless(self): + with testing.expect_deprecated_20( + r"The MetaData.bind argument is deprecated" + ): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + + with testing.expect_deprecated_20( + r"The DefaultGenerator.execute\(\) method is considered legacy " + "as of the 1.x", + ): + self._assert_seq_result(s.execute()) + + def test_explicit(self, connection): + s = Sequence("my_sequence") + with testing.expect_deprecated_20( + r"The DefaultGenerator.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.execute(connection)) + + def test_explicit_optional(self): + """test dialect executes a Sequence, returns nextval, whether + or not "optional" is set""" + + s = Sequence("my_sequence", optional=True) + with testing.expect_deprecated_20( + r"The DefaultGenerator.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.execute(testing.db)) + + def test_func_implicit_connectionless_execute(self): + """test func.next_value().execute()/.scalar() works + with connectionless execution.""" + + with testing.expect_deprecated_20( + r"The MetaData.bind argument is deprecated" + ): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.next_value().execute().scalar()) + + def test_func_explicit(self): + s = Sequence("my_sequence") + with testing.expect_deprecated_20( + r"The Engine.scalar\(\) method is considered legacy" + ): + self._assert_seq_result(testing.db.scalar(s.next_value())) + + def test_func_implicit_connectionless_scalar(self): + """test func.next_value().execute()/.scalar() works. """ + + with testing.expect_deprecated_20( + r"The MetaData.bind argument is deprecated" + ): + s = Sequence("my_sequence", metadata=MetaData(testing.db)) + with testing.expect_deprecated_20( + r"The Executable.execute\(\) method is considered legacy" + ): + self._assert_seq_result(s.next_value().scalar()) + + def test_func_embedded_select(self): + """test can use next_value() in select column expr""" + + s = Sequence("my_sequence") + with testing.expect_deprecated_20( + r"The Engine.scalar\(\) method is considered legacy" + ): + self._assert_seq_result(testing.db.scalar(select(s.next_value()))) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 7d05462abb..6d26f79758 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -84,7 +84,7 @@ class QueryTest(fixtures.TestBase): @engines.close_first def teardown(self): - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(addresses.delete()) conn.execute(users.delete()) conn.execute(users2.delete()) @@ -878,21 +878,22 @@ class RequiredBindTest(fixtures.TablesTest): ) def _assert_raises(self, stmt, params): - assert_raises_message( - exc.StatementError, - "A value is required for bind parameter 'x'", - testing.db.execute, - stmt, - **params - ) + with testing.db.connect() as conn: + assert_raises_message( + exc.StatementError, + "A value is required for bind parameter 'x'", + conn.execute, + stmt, + **params + ) - assert_raises_message( - exc.StatementError, - "A value is required for bind parameter 'x'", - testing.db.execute, - stmt, - params, - ) + assert_raises_message( + exc.StatementError, + "A value is required for bind parameter 'x'", + conn.execute, + stmt, + params, + ) def test_insert(self): stmt = self.tables.foo.insert().values( @@ -953,7 +954,7 @@ class LimitTest(fixtures.TestBase): ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute(users.insert(), user_id=1, user_name="john") conn.execute( addresses.insert(), address_id=1, user_id=1, address="addr1" @@ -1105,7 +1106,7 @@ class CompoundTest(fixtures.TestBase): ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute( t1.insert(), [ @@ -1470,7 +1471,7 @@ class JoinTest(fixtures.TestBase): metadata.drop_all() metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: # t1.10 -> t2.20 -> t3.30 # t1.11 -> t2.21 # t1.12 @@ -1823,7 +1824,7 @@ class OperatorTest(fixtures.TestBase): ) metadata.create_all() - with testing.db.connect() as conn: + with testing.db.begin() as conn: conn.execute( flds.insert(), [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")], diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 1c023e7b1f..a78d6c16b5 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -25,19 +25,12 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing.util import picklers -class QuoteExecTest(fixtures.TestBase): +class QuoteExecTest(fixtures.TablesTest): __backend__ = True @classmethod - def setup_class(cls): - # TODO: figure out which databases/which identifiers allow special - # characters to be used, such as: spaces, quote characters, - # punctuation characters, set up tests for those as well. - - global table1, table2 - metadata = MetaData(testing.db) - - table1 = Table( + def define_tables(cls, metadata): + Table( "WorstCase1", metadata, Column("lowercase", Integer, primary_key=True), @@ -45,7 +38,7 @@ class QuoteExecTest(fixtures.TestBase): Column("MixedCase", Integer), Column("ASC", Integer, key="a123"), ) - table2 = Table( + Table( "WorstCase2", metadata, Column("desc", Integer, primary_key=True, key="d123"), @@ -53,18 +46,6 @@ class QuoteExecTest(fixtures.TestBase): Column("MixedCase", Integer), ) - table1.create() - table2.create() - - def teardown(self): - table1.delete().execute() - table2.delete().execute() - - @classmethod - def teardown_class(cls): - table1.drop() - table2.drop() - def test_reflect(self): meta2 = MetaData() t2 = Table("WorstCase1", meta2, autoload_with=testing.db, quote=True) @@ -88,25 +69,22 @@ class QuoteExecTest(fixtures.TestBase): assert "MixedCase" in t2.c @testing.provide_metadata - def test_has_table_case_sensitive(self): + def test_has_table_case_sensitive(self, connection): preparer = testing.db.dialect.identifier_preparer - with testing.db.connect() as conn: - if conn.dialect.requires_name_normalize: - conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)") - else: - conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)") - conn.exec_driver_sql( - "CREATE TABLE %s (id INTEGER)" - % preparer.quote_identifier("tab2") - ) - conn.exec_driver_sql( - "CREATE TABLE %s (id INTEGER)" - % preparer.quote_identifier("TAB3") - ) - conn.exec_driver_sql( - "CREATE TABLE %s (id INTEGER)" - % preparer.quote_identifier("TAB4") - ) + conn = connection + if conn.dialect.requires_name_normalize: + conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)") + else: + conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)") + conn.exec_driver_sql( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("tab2") + ) + conn.exec_driver_sql( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB3") + ) + conn.exec_driver_sql( + "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB4") + ) t1 = Table( "tab1", self.metadata, Column("id", Integer, primary_key=True) @@ -127,7 +105,7 @@ class QuoteExecTest(fixtures.TestBase): quote=True, ) - insp = inspect(testing.db) + insp = inspect(connection) assert insp.has_table(t1.name) eq_([c["name"] for c in insp.get_columns(t1.name)], ["id"]) @@ -140,16 +118,24 @@ class QuoteExecTest(fixtures.TestBase): assert insp.has_table(t4.name) eq_([c["name"] for c in insp.get_columns(t4.name)], ["id"]) - def test_basic(self): - table1.insert().execute( - {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + def test_basic(self, connection): + table1, table2 = self.tables("WorstCase1", "WorstCase2") + + connection.execute( + table1.insert(), + [ + {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + ], ) - table2.insert().execute( - {"d123": 1, "u123": 2, "MixedCase": 3}, - {"d123": 2, "u123": 2, "MixedCase": 3}, - {"d123": 4, "u123": 3, "MixedCase": 2}, + connection.execute( + table2.insert(), + [ + {"d123": 1, "u123": 2, "MixedCase": 3}, + {"d123": 2, "u123": 2, "MixedCase": 3}, + {"d123": 4, "u123": 3, "MixedCase": 2}, + ], ) columns = [ @@ -158,23 +144,30 @@ class QuoteExecTest(fixtures.TestBase): table1.c.MixedCase, table1.c.a123, ] - result = select(columns).execute().fetchall() + result = connection.execute(select(columns)).all() assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)] columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase] - result = select(columns).execute().fetchall() + result = connection.execute(select(columns)).all() assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)] - def test_use_labels(self): - table1.insert().execute( - {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, - {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, - ) - table2.insert().execute( - {"d123": 1, "u123": 2, "MixedCase": 3}, - {"d123": 2, "u123": 2, "MixedCase": 3}, - {"d123": 4, "u123": 3, "MixedCase": 2}, + def test_use_labels(self, connection): + table1, table2 = self.tables("WorstCase1", "WorstCase2") + connection.execute( + table1.insert(), + [ + {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4}, + {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1}, + ], + ) + connection.execute( + table2.insert(), + [ + {"d123": 1, "u123": 2, "MixedCase": 3}, + {"d123": 2, "u123": 2, "MixedCase": 3}, + {"d123": 4, "u123": 3, "MixedCase": 2}, + ], ) columns = [ @@ -183,11 +176,11 @@ class QuoteExecTest(fixtures.TestBase): table1.c.MixedCase, table1.c.a123, ] - result = select(columns, use_labels=True).execute().fetchall() + result = connection.execute(select(columns).apply_labels()).fetchall() assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)] columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase] - result = select(columns, use_labels=True).execute().fetchall() + result = connection.execute(select(columns).apply_labels()).all() assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)] diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 9ef533be3a..db0e0d4c81 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -615,63 +615,6 @@ class CursorResultTest(fixtures.TablesTest): result.fetchone, ) - def test_connectionless_autoclose_rows_exhausted(self): - # TODO: deprecate for 2.0 - users = self.tables.users - with testing.db.connect() as conn: - conn.execute(users.insert(), dict(user_id=1, user_name="john")) - - result = testing.db.execute(text("select * from users")) - connection = result.connection - assert not connection.closed - eq_(result.fetchone(), (1, "john")) - assert not connection.closed - eq_(result.fetchone(), None) - assert connection.closed - - @testing.requires.returning - def test_connectionless_autoclose_crud_rows_exhausted(self): - # TODO: deprecate for 2.0 - users = self.tables.users - stmt = ( - users.insert() - .values(user_id=1, user_name="john") - .returning(users.c.user_id) - ) - result = testing.db.execute(stmt) - connection = result.connection - assert not connection.closed - eq_(result.fetchone(), (1,)) - assert not connection.closed - eq_(result.fetchone(), None) - assert connection.closed - - def test_connectionless_autoclose_no_rows(self): - # TODO: deprecate for 2.0 - result = testing.db.execute(text("select * from users")) - connection = result.connection - assert not connection.closed - eq_(result.fetchone(), None) - assert connection.closed - - @testing.requires.updateable_autoincrement_pks - def test_connectionless_autoclose_no_metadata(self): - # TODO: deprecate for 2.0 - result = testing.db.execute(text("update users set user_id=5")) - connection = result.connection - assert connection.closed - - assert_raises_message( - exc.ResourceClosedError, - "This result object does not return rows.", - result.fetchone, - ) - assert_raises_message( - exc.ResourceClosedError, - "This result object does not return rows.", - result.keys, - ) - def test_row_case_sensitive(self, connection): row = connection.execute( select( @@ -1285,7 +1228,7 @@ class CursorResultTest(fixtures.TablesTest): with patch.object( engine.dialect.execution_ctx_cls, "rowcount" ) as mock_rowcount: - with engine.connect() as conn: + with engine.begin() as conn: mock_rowcount.__get__ = Mock() conn.execute( t.insert(), {"data": "d1"}, {"data": "d2"}, {"data": "d3"} @@ -1362,20 +1305,14 @@ class CursorResultTest(fixtures.TablesTest): eq_(row[1:0:-1], ("Uno",)) @testing.requires.cextensions - def test_row_c_sequence_check(self): - # TODO: modernize for 2.0 - metadata = MetaData() - metadata.bind = "sqlite://" - users = Table( - "users", - metadata, - Column("id", Integer, primary_key=True), - Column("name", String(40)), - ) - users.create() + @testing.provide_metadata + def test_row_c_sequence_check(self, connection): + users = self.tables.users2 - users.insert().execute(name="Test") - row = users.select().execute().fetchone() + connection.execute(users.insert(), dict(user_id=1, user_name="Test")) + row = connection.execute( + users.select().where(users.c.user_id == 1) + ).fetchone() s = util.StringIO() writer = csv.writer(s) @@ -2340,7 +2277,7 @@ class AlternateCursorResultTest(fixtures.TablesTest): @testing.fixture def row_growth_fixture(self): with self._proxy_fixture(_cursor.BufferedRowCursorFetchStrategy): - with self.engine.connect() as conn: + with self.engine.begin() as conn: conn.execute( self.table.insert(), [{"x": i, "y": "t_%d" % i} for i in range(15, 3000)], diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 065205c45a..9f2afd7b7d 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -23,9 +23,6 @@ from sqlalchemy.testing.schema import Table from sqlalchemy.types import TypeDecorator -table = GoofyType = seq = None - - class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "postgresql" @@ -92,14 +89,14 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): ) -class ReturningTest(fixtures.TestBase, AssertsExecutionResults): +class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): __requires__ = ("returning",) __backend__ = True - def setup(self): - meta = MetaData(testing.db) - global table, GoofyType + run_create_tables = "each" + @classmethod + def define_tables(cls, metadata): class GoofyType(TypeDecorator): impl = String @@ -113,9 +110,11 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): return None return value + "BAR" - table = Table( + cls.GoofyType = GoofyType + + Table( "tables", - meta, + metadata, Column( "id", Integer, primary_key=True, test_needs_autoincrement=True ), @@ -123,14 +122,9 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): Column("full", Boolean), Column("goofy", GoofyType(50)), ) - with testing.db.connect() as conn: - table.create(conn, checkfirst=True) - - def teardown(self): - with testing.db.connect() as conn: - table.drop(conn) def test_column_targeting(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.id, table.c.full), {"persons": 1, "full": False}, @@ -155,6 +149,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.fails_on("firebird", "fb can't handle returning x AS y") def test_labeling(self, connection): + table = self.tables.tables result = connection.execute( table.insert() .values(persons=6) @@ -167,6 +162,8 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): "firebird", "fb/kintersbasdb can't handle the bind params" ) def test_anon_expressions(self, connection): + table = self.tables.tables + GoofyType = self.GoofyType result = connection.execute( table.insert() .values(goofy="someOTHERgoofy") @@ -182,6 +179,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(row[0], 30) def test_update_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -201,6 +199,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.full_returning def test_update_full_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -215,6 +214,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.full_returning def test_delete_full_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -226,6 +226,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result.fetchall(), [(1, False), (2, False)]) def test_insert_returning(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} ) @@ -234,6 +235,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.multivalues_inserts def test_multirow_returning(self, connection): + table = self.tables.tables ins = ( table.insert() .returning(table.c.id, table.c.persons) @@ -249,6 +251,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) def test_no_ipk_on_returning(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.id), {"persons": 1, "full": False} ) @@ -274,6 +277,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults): eq_([dict(row._mapping) for row in result4], [{"persons": 10}]) def test_delete_returning(self, connection): + table = self.tables.tables connection.execute( table.insert(), [{"persons": 5, "full": False}, {"persons": 3, "full": False}], @@ -319,17 +323,16 @@ class CompositeStatementTest(fixtures.TestBase): eq_(result.scalar(), 5) -class SequenceReturningTest(fixtures.TestBase): +class SequenceReturningTest(fixtures.TablesTest): __requires__ = "returning", "sequences" __backend__ = True - def setup(self): - meta = MetaData(testing.db) - global table, seq + @classmethod + def define_tables(cls, metadata): seq = Sequence("tid_seq") - table = Table( + Table( "tables", - meta, + metadata, Column( "id", Integer, @@ -338,38 +341,32 @@ class SequenceReturningTest(fixtures.TestBase): ), Column("data", String(50)), ) - with testing.db.connect() as conn: - table.create(conn, checkfirst=True) - - def teardown(self): - with testing.db.connect() as conn: - table.drop(conn) + cls.sequences.tid_seq = seq def test_insert(self, connection): + table = self.tables.tables r = connection.execute( table.insert().values(data="hi").returning(table.c.id) ) eq_(r.first(), tuple([testing.db.dialect.default_sequence_base])) eq_( - connection.execute(seq), + connection.execute(self.sequences.tid_seq), testing.db.dialect.default_sequence_base + 1, ) -class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults): +class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): """test returning() works with columns that define 'key'.""" __requires__ = ("returning",) __backend__ = True - def setup(self): - meta = MetaData(testing.db) - global table - - table = Table( + @classmethod + def define_tables(cls, metadata): + Table( "tables", - meta, + metadata, Column( "id", Integer, @@ -379,16 +376,11 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults): ), Column("data", String(20)), ) - with testing.db.connect() as conn: - table.create(conn, checkfirst=True) - - def teardown(self): - with testing.db.connect() as conn: - table.drop(conn) @testing.exclude("firebird", "<", (2, 0), "2.0+ feature") @testing.exclude("postgresql", "<", (8, 2), "8.2+ feature") def test_insert(self, connection): + table = self.tables.tables result = connection.execute( table.insert().returning(table.c.foo_id), data="somedata" ) diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index e609a8a916..1809e0cca0 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -95,64 +95,6 @@ class SequenceDDLTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) -class LegacySequenceExecTest(fixtures.TestBase): - __requires__ = ("sequences",) - __backend__ = True - - @classmethod - def setup_class(cls): - cls.seq = Sequence("my_sequence") - cls.seq.create(testing.db) - - @classmethod - def teardown_class(cls): - cls.seq.drop(testing.db) - - def _assert_seq_result(self, ret): - """asserts return of next_value is an int""" - - assert isinstance(ret, util.int_types) - assert ret >= testing.db.dialect.default_sequence_base - - def test_implicit_connectionless(self): - s = Sequence("my_sequence", metadata=MetaData(testing.db)) - self._assert_seq_result(s.execute()) - - def test_explicit(self, connection): - s = Sequence("my_sequence") - self._assert_seq_result(s.execute(connection)) - - def test_explicit_optional(self): - """test dialect executes a Sequence, returns nextval, whether - or not "optional" is set""" - - s = Sequence("my_sequence", optional=True) - self._assert_seq_result(s.execute(testing.db)) - - def test_func_implicit_connectionless_execute(self): - """test func.next_value().execute()/.scalar() works - with connectionless execution.""" - - s = Sequence("my_sequence", metadata=MetaData(testing.db)) - self._assert_seq_result(s.next_value().execute().scalar()) - - def test_func_explicit(self): - s = Sequence("my_sequence") - self._assert_seq_result(testing.db.scalar(s.next_value())) - - def test_func_implicit_connectionless_scalar(self): - """test func.next_value().execute()/.scalar() works. """ - - s = Sequence("my_sequence", metadata=MetaData(testing.db)) - self._assert_seq_result(s.next_value().scalar()) - - def test_func_embedded_select(self): - """test can use next_value() in select column expr""" - - s = Sequence("my_sequence") - self._assert_seq_result(testing.db.scalar(select(s.next_value()))) - - class SequenceExecTest(fixtures.TestBase): __requires__ = ("sequences",) __backend__ = True @@ -247,7 +189,7 @@ class SequenceExecTest(fixtures.TestBase): s = Sequence("my_sequence_here", metadata=metadata) e = engines.testing_engine(options={"implicit_returning": False}) - with e.connect() as conn: + with e.begin() as conn: t1.create(conn) s.create(conn) @@ -279,7 +221,7 @@ class SequenceExecTest(fixtures.TestBase): t1.create(testing.db) e = engines.testing_engine(options={"implicit_returning": True}) - with e.connect() as conn: + with e.begin() as conn: r = conn.execute(t1.insert().values(x=s.next_value())) self._assert_seq_result(r.inserted_primary_key[0]) @@ -476,7 +418,7 @@ class TableBoundSequenceTest(fixtures.TablesTest): engine = engines.testing_engine(options={"implicit_returning": False}) - with engine.connect() as conn: + with engine.begin() as conn: result = conn.execute(sometable.insert(), dict(name="somename")) eq_(result.postfetch_cols(), [sometable.c.obj_id]) diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index 09ade319e2..719f8e3187 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -359,34 +359,34 @@ class RoundTripTestBase(object): [("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")], ) - def test_targeting_no_labels(self): - testing.db.execute( + def test_targeting_no_labels(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute(select(self.tables.test_table)).first() + row = connection.execute(select(self.tables.test_table)).first() eq_(row._mapping[self.tables.test_table.c.y], "Y1") - def test_targeting_by_string(self): - testing.db.execute( + def test_targeting_by_string(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute(select(self.tables.test_table)).first() + row = connection.execute(select(self.tables.test_table)).first() eq_(row._mapping["y"], "Y1") - def test_targeting_apply_labels(self): - testing.db.execute( + def test_targeting_apply_labels(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute( + row = connection.execute( select(self.tables.test_table).apply_labels() ).first() eq_(row._mapping[self.tables.test_table.c.y], "Y1") - def test_targeting_individual_labels(self): - testing.db.execute( + def test_targeting_individual_labels(self, connection): + connection.execute( self.tables.test_table.insert(), {"x": "X1", "y": "Y1"} ) - row = testing.db.execute( + row = connection.execute( select( self.tables.test_table.c.x.label("xbar"), self.tables.test_table.c.y.label("ybar"), @@ -450,9 +450,9 @@ class ReturningTest(fixtures.TablesTest): ) @testing.provide_metadata - def test_insert_returning(self): + def test_insert_returning(self, connection): table = self.tables.test_table - result = testing.db.execute( + result = connection.execute( table.insert().returning(table.c.y), {"x": "xvalue"} ) eq_(result.first(), ("yvalue",)) diff --git a/test/sql/test_types.py b/test/sql/test_types.py index fd1783e098..3f89d438a6 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -535,49 +535,48 @@ class _UserDefinedTypeFixture(object): class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): __backend__ = True - def _data_fixture(self): + def _data_fixture(self, connection): users = self.tables.users - with testing.db.connect() as conn: - conn.execute( - users.insert(), - dict( - user_id=2, - goofy="jack", - goofy2="jack", - goofy4=util.u("jack"), - goofy7=util.u("jack"), - goofy8=12, - goofy9=12, - ), - ) - conn.execute( - users.insert(), - dict( - user_id=3, - goofy="lala", - goofy2="lala", - goofy4=util.u("lala"), - goofy7=util.u("lala"), - goofy8=15, - goofy9=15, - ), - ) - conn.execute( - users.insert(), - dict( - user_id=4, - goofy="fred", - goofy2="fred", - goofy4=util.u("fred"), - goofy7=util.u("fred"), - goofy8=9, - goofy9=9, - ), - ) + connection.execute( + users.insert(), + dict( + user_id=2, + goofy="jack", + goofy2="jack", + goofy4=util.u("jack"), + goofy7=util.u("jack"), + goofy8=12, + goofy9=12, + ), + ) + connection.execute( + users.insert(), + dict( + user_id=3, + goofy="lala", + goofy2="lala", + goofy4=util.u("lala"), + goofy7=util.u("lala"), + goofy8=15, + goofy9=15, + ), + ) + connection.execute( + users.insert(), + dict( + user_id=4, + goofy="fred", + goofy2="fred", + goofy4=util.u("fred"), + goofy7=util.u("fred"), + goofy8=9, + goofy9=9, + ), + ) def test_processing(self, connection): users = self.tables.users - self._data_fixture() + self._data_fixture(connection) result = connection.execute( users.select().order_by(users.c.user_id) @@ -601,7 +600,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): def test_plain_in(self, connection): users = self.tables.users - self._data_fixture() + self._data_fixture(connection) stmt = ( select(users.c.user_id, users.c.goofy8) @@ -613,7 +612,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest): def test_expanding_in(self, connection): users = self.tables.users - self._data_fixture() + self._data_fixture(connection) stmt = ( select(users.c.user_id, users.c.goofy8) @@ -1225,41 +1224,38 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): @testing.only_on("sqlite") @testing.provide_metadata - def test_round_trip(self): + def test_round_trip(self, connection): variant = self.UTypeOne().with_variant(self.UTypeTwo(), "sqlite") t = Table("t", self.metadata, Column("x", variant)) - with testing.db.connect() as conn: - t.create(conn) + t.create(connection) - conn.execute(t.insert(), x="foo") + connection.execute(t.insert(), x="foo") - eq_(conn.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO") + eq_(connection.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO") @testing.only_on("sqlite") @testing.provide_metadata - def test_round_trip_sqlite_datetime(self): + def test_round_trip_sqlite_datetime(self, connection): variant = DateTime().with_variant( dialects.sqlite.DATETIME(truncate_microseconds=True), "sqlite" ) t = Table("t", self.metadata, Column("x", variant)) - with testing.db.connect() as conn: - t.create(conn) + t.create(connection) - conn.execute( - t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839) - ) + connection.execute( + t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839) + ) - eq_( - conn.scalar( - select(t.c.x).where( - t.c.x - == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059) - ) - ), - datetime.datetime(2015, 4, 18, 10, 15, 17), - ) + eq_( + connection.scalar( + select(t.c.x).where( + t.c.x == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059) + ) + ), + datetime.datetime(2015, 4, 18, 10, 15, 17), + ) class UnicodeTest(fixtures.TestBase): @@ -1702,14 +1698,25 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): 2, ) - with testing.db.connect() as conn: - self.metadata.create_all(conn) + self.metadata.create_all(testing.db) + + # not using the connection fixture because we need to rollback and + # start again in the middle + with testing.db.connect() as connection: + # postgresql needs this in order to continue after the exception + trans = connection.begin() assert_raises( (exc.DBAPIError,), - conn.exec_driver_sql, + connection.exec_driver_sql, "insert into my_table " "(data) values('four')", ) - conn.exec_driver_sql("insert into my_table (data) values ('two')") + trans.rollback() + + with connection.begin(): + connection.exec_driver_sql( + "insert into my_table (data) values ('two')" + ) + eq_(connection.execute(select(t.c.data)).scalar(), "two") @testing.requires.enforces_check_constraints @testing.provide_metadata @@ -1747,34 +1754,44 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): 2, ) - with testing.db.connect() as conn: - self.metadata.create_all(conn) + self.metadata.create_all(testing.db) + + # not using the connection fixture because we need to rollback and + # start again in the middle + with testing.db.connect() as connection: + # postgresql needs this in order to continue after the exception + trans = connection.begin() assert_raises( (exc.DBAPIError,), - conn.exec_driver_sql, + connection.exec_driver_sql, "insert into my_table " "(data) values('two')", ) - conn.exec_driver_sql("insert into my_table (data) values ('four')") + trans.rollback() - def test_skip_check_constraint(self): - with testing.db.connect() as conn: - conn.exec_driver_sql( - "insert into non_native_enum_table " - "(id, someotherenum) values(1, 'four')" - ) - eq_( - conn.exec_driver_sql( - "select someotherenum from non_native_enum_table" - ).scalar(), - "four", - ) - assert_raises_message( - LookupError, - "'four' is not among the defined enum values. " - "Enum name: None. Possible values: one, two, three", - conn.scalar, - select(self.tables.non_native_enum_table.c.someotherenum), - ) + with connection.begin(): + connection.exec_driver_sql( + "insert into my_table (data) values ('four')" + ) + eq_(connection.execute(select(t.c.data)).scalar(), "four") + + def test_skip_check_constraint(self, connection): + connection.exec_driver_sql( + "insert into non_native_enum_table " + "(id, someotherenum) values(1, 'four')" + ) + eq_( + connection.exec_driver_sql( + "select someotherenum from non_native_enum_table" + ).scalar(), + "four", + ) + assert_raises_message( + LookupError, + "'four' is not among the defined enum values. " + "Enum name: None. Possible values: one, two, three", + connection.scalar, + select(self.tables.non_native_enum_table.c.someotherenum), + ) def test_non_native_round_trip(self, connection): non_native_enum_table = self.tables["non_native_enum_table"] @@ -2086,15 +2103,15 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest): eq_(e.length, 42) -binary_table = MyPickleType = metadata = None +MyPickleType = None -class BinaryTest(fixtures.TestBase, AssertsExecutionResults): +class BinaryTest(fixtures.TablesTest, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): - global binary_table, MyPickleType, metadata + def define_tables(cls, metadata): + global MyPickleType class MyPickleType(types.TypeDecorator): impl = PickleType @@ -2109,8 +2126,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): value.stuff = "this is the right stuff" return value - metadata = MetaData(testing.db) - binary_table = Table( + Table( "binary_table", metadata, Column( @@ -2125,19 +2141,11 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): Column("pickled", PickleType), Column("mypickle", MyPickleType), ) - metadata.create_all() - - @engines.close_first - def teardown(self): - with testing.db.connect() as conn: - conn.execute(binary_table.delete()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() @testing.requires.non_broken_binary def test_round_trip(self, connection): + binary_table = self.tables.binary_table + testobj1 = pickleable.Foo("im foo 1") testobj2 = pickleable.Foo("im foo 2") testobj3 = pickleable.Foo("im foo 3") @@ -2197,6 +2205,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults): @testing.requires.binary_comparisons def test_comparison(self, connection): """test that type coercion occurs on comparison for binary""" + binary_table = self.tables.binary_table expr = binary_table.c.data == "foo" assert isinstance(expr.right.type, LargeBinary) @@ -2419,17 +2428,17 @@ class ArrayTest(fixtures.TestBase): assert isinstance(arrtable.c.strarr[1:3].type, MyArray) -test_table = meta = MyCustomType = MyTypeDec = None +MyCustomType = MyTypeDec = None class ExpressionTest( - fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL + fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL ): __dialect__ = "default" @classmethod - def setup_class(cls): - global test_table, meta, MyCustomType, MyTypeDec + def define_tables(cls, metadata): + global MyCustomType, MyTypeDec class MyCustomType(types.UserDefinedType): def get_col_spec(self): @@ -2463,10 +2472,9 @@ class ExpressionTest( def process_result_value(self, value, dialect): return value + "BIND_OUT" - meta = MetaData(testing.db) - test_table = Table( + Table( "test", - meta, + metadata, Column("id", Integer, primary_key=True), Column("data", String(30)), Column("atimestamp", Date), @@ -2474,25 +2482,22 @@ class ExpressionTest( Column("bvalue", MyTypeDec(50)), ) - meta.create_all() - - with testing.db.connect() as conn: - conn.execute( - test_table.insert(), - { - "id": 1, - "data": "somedata", - "atimestamp": datetime.date(2007, 10, 15), - "avalue": 25, - "bvalue": "foo", - }, - ) - @classmethod - def teardown_class(cls): - meta.drop_all() + def insert_data(cls, connection): + test_table = cls.tables.test + connection.execute( + test_table.insert(), + { + "id": 1, + "data": "somedata", + "atimestamp": datetime.date(2007, 10, 15), + "avalue": 25, + "bvalue": "foo", + }, + ) def test_control(self, connection): + test_table = self.tables.test assert ( connection.exec_driver_sql("select avalue from test").scalar() == 250 @@ -2513,6 +2518,9 @@ class ExpressionTest( def test_bind_adapt(self, connection): # test an untyped bind gets the left side's type + + test_table = self.tables.test + expr = test_table.c.atimestamp == bindparam("thedate") eq_(expr.right.type._type_affinity, Date) @@ -2565,6 +2573,8 @@ class ExpressionTest( ) def test_grouped_bind_adapt(self): + test_table = self.tables.test + expr = test_table.c.atimestamp == elements.Grouping( bindparam("thedate") ) @@ -2579,6 +2589,8 @@ class ExpressionTest( eq_(expr.right.element.element.type._type_affinity, Date) def test_bind_adapt_update(self): + test_table = self.tables.test + bp = bindparam("somevalue") stmt = test_table.update().values(avalue=bp) compiled = stmt.compile() @@ -2586,13 +2598,17 @@ class ExpressionTest( eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType) def test_bind_adapt_insert(self): + test_table = self.tables.test bp = bindparam("somevalue") + stmt = test_table.insert().values(avalue=bp) compiled = stmt.compile() eq_(bp.type._type_affinity, types.NullType) eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType) def test_bind_adapt_expression(self): + test_table = self.tables.test + bp = bindparam("somevalue") stmt = test_table.c.avalue == bp eq_(bp.type._type_affinity, types.NullType) @@ -2629,6 +2645,8 @@ class ExpressionTest( is_(literal(data).type.__class__, expected) def test_typedec_operator_adapt(self, connection): + test_table = self.tables.test + expr = test_table.c.bvalue + "hi" assert expr.type.__class__ is MyTypeDec @@ -2846,6 +2864,8 @@ class ExpressionTest( eq_(expr.type, types.NULLTYPE) def test_distinct(self, connection): + test_table = self.tables.test + s = select(distinct(test_table.c.avalue)) eq_(connection.execute(s).scalar(), 25) @@ -3004,17 +3024,18 @@ class NumericRawSQLTest(fixtures.TestBase): __backend__ = True - def _fixture(self, metadata, type_, data): + def _fixture(self, connection, metadata, type_, data): t = Table("t", metadata, Column("val", type_)) - metadata.create_all() - with testing.db.connect() as conn: - conn.execute(t.insert(), val=data) + metadata.create_all(connection) + connection.execute(t.insert(), val=data) @testing.fails_on("sqlite", "Doesn't provide Decimal results natively") @testing.provide_metadata def test_decimal_fp(self, connection): metadata = self.metadata - self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45.5")) + self._fixture( + connection, metadata, Numeric(10, 5), decimal.Decimal("45.5") + ) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, decimal.Decimal) eq_(val, decimal.Decimal("45.5")) @@ -3023,7 +3044,9 @@ class NumericRawSQLTest(fixtures.TestBase): @testing.provide_metadata def test_decimal_int(self, connection): metadata = self.metadata - self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45")) + self._fixture( + connection, metadata, Numeric(10, 5), decimal.Decimal("45") + ) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, decimal.Decimal) eq_(val, decimal.Decimal("45")) @@ -3031,7 +3054,7 @@ class NumericRawSQLTest(fixtures.TestBase): @testing.provide_metadata def test_ints(self, connection): metadata = self.metadata - self._fixture(metadata, Integer, 45) + self._fixture(connection, metadata, Integer, 45) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, util.int_types) eq_(val, 45) @@ -3039,7 +3062,7 @@ class NumericRawSQLTest(fixtures.TestBase): @testing.provide_metadata def test_float(self, connection): metadata = self.metadata - self._fixture(metadata, Float, 46.583) + self._fixture(connection, metadata, Float, 46.583) val = connection.exec_driver_sql("select val from t").scalar() assert isinstance(val, float) @@ -3050,19 +3073,14 @@ class NumericRawSQLTest(fixtures.TestBase): eq_(val, 46.583) -interval_table = metadata = None - - -class IntervalTest(fixtures.TestBase, AssertsExecutionResults): +class IntervalTest(fixtures.TablesTest, AssertsExecutionResults): __backend__ = True @classmethod - def setup_class(cls): - global interval_table, metadata - metadata = MetaData(testing.db) - interval_table = Table( - "intervaltable", + def define_tables(cls, metadata): + Table( + "intervals", metadata, Column( "id", Integer, primary_key=True, test_needs_autoincrement=True @@ -3074,16 +3092,6 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults): ), Column("non_native_interval", Interval(native=False)), ) - metadata.create_all() - - @engines.close_first - def teardown(self): - with testing.db.connect() as conn: - conn.execute(interval_table.delete()) - - @classmethod - def teardown_class(cls): - metadata.drop_all() def test_non_native_adapt(self): interval = Interval(native=False) @@ -3092,30 +3100,32 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults): assert adapted.native is False eq_(str(adapted), "DATETIME") - def test_roundtrip(self): + def test_roundtrip(self, connection): + interval_table = self.tables.intervals + small_delta = datetime.timedelta(days=15, seconds=5874) delta = datetime.timedelta(14) - with testing.db.begin() as conn: - conn.execute( - interval_table.insert(), - native_interval=small_delta, - native_interval_args=delta, - non_native_interval=delta, - ) - row = conn.execute(interval_table.select()).first() + connection.execute( + interval_table.insert(), + native_interval=small_delta, + native_interval_args=delta, + non_native_interval=delta, + ) + row = connection.execute(interval_table.select()).first() eq_(row.native_interval, small_delta) eq_(row.native_interval_args, delta) eq_(row.non_native_interval, delta) - def test_null(self): - with testing.db.begin() as conn: - conn.execute( - interval_table.insert(), - id=1, - native_inverval=None, - non_native_interval=None, - ) - row = conn.execute(interval_table.select()).first() + def test_null(self, connection): + interval_table = self.tables.intervals + + connection.execute( + interval_table.insert(), + id=1, + native_inverval=None, + non_native_interval=None, + ) + row = connection.execute(interval_table.select()).first() eq_(row.native_interval, None) eq_(row.native_interval_args, None) eq_(row.non_native_interval, None) @@ -3215,25 +3225,24 @@ class BooleanTest( ) @testing.requires.non_native_boolean_unconstrained - def test_nonnative_processor_coerces_integer_to_boolean(self): + def test_nonnative_processor_coerces_integer_to_boolean(self, connection): boolean_table = self.tables.boolean_table - with testing.db.connect() as conn: - conn.exec_driver_sql( - "insert into boolean_table (id, unconstrained_value) " - "values (1, 5)" - ) + connection.exec_driver_sql( + "insert into boolean_table (id, unconstrained_value) " + "values (1, 5)" + ) - eq_( - conn.exec_driver_sql( - "select unconstrained_value from boolean_table" - ).scalar(), - 5, - ) + eq_( + connection.exec_driver_sql( + "select unconstrained_value from boolean_table" + ).scalar(), + 5, + ) - eq_( - conn.scalar(select(boolean_table.c.unconstrained_value)), - True, - ) + eq_( + connection.scalar(select(boolean_table.c.unconstrained_value)), + True, + ) def test_bind_processor_coercion_native_true(self): proc = Boolean().bind_processor( diff --git a/test/sql/test_update.py b/test/sql/test_update.py index ec96af207e..946a01651a 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -1263,10 +1263,10 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): __backend__ = True @testing.requires.update_from - def test_exec_two_table(self): + def test_exec_two_table(self, connection): users, addresses = self.tables.users, self.tables.addresses - testing.db.execute( + connection.execute( addresses.update() .values(email_address=users.c.name) .where(users.c.id == addresses.c.user_id) @@ -1280,14 +1280,14 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "ed"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) @testing.requires.update_from - def test_exec_two_table_plus_alias(self): + def test_exec_two_table_plus_alias(self, connection): users, addresses = self.tables.users, self.tables.addresses a1 = addresses.alias() - testing.db.execute( + connection.execute( addresses.update() .values(email_address=users.c.name) .where(users.c.id == a1.c.user_id) @@ -1302,15 +1302,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "ed"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) @testing.requires.update_from - def test_exec_three_table(self): + def test_exec_three_table(self, connection): users = self.tables.users addresses = self.tables.addresses dingalings = self.tables.dingalings - testing.db.execute( + connection.execute( addresses.update() .values(email_address=users.c.name) .where(users.c.id == addresses.c.user_id) @@ -1326,15 +1326,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "ed@lala.com"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) @testing.only_on("mysql", "Multi table update") - def test_exec_multitable(self): + def test_exec_multitable(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.email_address: "updated", users.c.name: "ed2"} - testing.db.execute( + connection.execute( addresses.update() .values(values) .where(users.c.id == addresses.c.user_id) @@ -1348,18 +1348,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "updated"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_exec_join_multitable(self): + def test_exec_join_multitable(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.email_address: "updated", users.c.name: "ed2"} - testing.db.execute( + connection.execute( update(users.join(addresses)) .values(values) .where(users.c.name == "ed") @@ -1372,18 +1372,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "x", "updated"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_exec_multitable_same_name(self): + def test_exec_multitable_same_name(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.name: "ad_ed2", users.c.name: "ed2"} - testing.db.execute( + connection.execute( addresses.update() .values(values) .where(users.c.id == addresses.c.user_id) @@ -1397,18 +1397,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (4, 8, "ad_ed2", "ed@lala.com"), (5, 9, "x", "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) - def _assert_addresses(self, addresses, expected): + def _assert_addresses(self, connection, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) - def _assert_users(self, users, expected): + def _assert_users(self, connection, users, expected): stmt = users.select().order_by(users.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) class UpdateFromMultiTableUpdateDefaultsTest( @@ -1472,12 +1472,12 @@ class UpdateFromMultiTableUpdateDefaultsTest( ) @testing.only_on("mysql", "Multi table update") - def test_defaults_second_table(self): + def test_defaults_second_table(self, connection): users, addresses = self.tables.users, self.tables.addresses values = {addresses.c.email_address: "updated", users.c.name: "ed2"} - ret = testing.db.execute( + ret = connection.execute( addresses.update() .values(values) .where(users.c.id == addresses.c.user_id) @@ -1491,18 +1491,18 @@ class UpdateFromMultiTableUpdateDefaultsTest( (3, 8, "updated"), (4, 9, "fred@fred.com"), ] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) expected = [(8, "ed2", "im the update"), (9, "fred", "value")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_defaults_second_table_same_name(self): + def test_defaults_second_table_same_name(self, connection): users, foobar = self.tables.users, self.tables.foobar values = {foobar.c.data: foobar.c.data + "a", users.c.name: "ed2"} - ret = testing.db.execute( + ret = connection.execute( users.update() .values(values) .where(users.c.id == foobar.c.user_id) @@ -1519,16 +1519,16 @@ class UpdateFromMultiTableUpdateDefaultsTest( (3, 8, "d2a", "im the other update"), (4, 9, "d3", None), ] - self._assert_foobar(foobar, expected) + self._assert_foobar(connection, foobar, expected) expected = [(8, "ed2", "im the update"), (9, "fred", "value")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) @testing.only_on("mysql", "Multi table update") - def test_no_defaults_second_table(self): + def test_no_defaults_second_table(self, connection): users, addresses = self.tables.users, self.tables.addresses - ret = testing.db.execute( + ret = connection.execute( addresses.update() .values({"email_address": users.c.name}) .where(users.c.id == addresses.c.user_id) @@ -1538,20 +1538,20 @@ class UpdateFromMultiTableUpdateDefaultsTest( eq_(ret.prefetch_cols(), []) expected = [(2, 8, "ed"), (3, 8, "ed"), (4, 9, "fred@fred.com")] - self._assert_addresses(addresses, expected) + self._assert_addresses(connection, addresses, expected) # users table not actually updated, so no onupdate expected = [(8, "ed", "value"), (9, "fred", "value")] - self._assert_users(users, expected) + self._assert_users(connection, users, expected) - def _assert_foobar(self, foobar, expected): + def _assert_foobar(self, connection, foobar, expected): stmt = foobar.select().order_by(foobar.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) - def _assert_addresses(self, addresses, expected): + def _assert_addresses(self, connection, addresses, expected): stmt = addresses.select().order_by(addresses.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) - def _assert_users(self, users, expected): + def _assert_users(self, connection, users, expected): stmt = users.select().order_by(users.c.id) - eq_(testing.db.execute(stmt).fetchall(), expected) + eq_(connection.execute(stmt).fetchall(), expected) diff --git a/tox.ini b/tox.ini index 6cfcf62efc..e1aef1a23d 100644 --- a/tox.ini +++ b/tox.ini @@ -56,7 +56,6 @@ setenv= PYTHONPATH= PYTHONNOUSERSITE=1 MEMUSAGE=--nomemory - SQLALCHEMY_WARN_20=true BASECOMMAND=python -m pytest --rootdir {toxinidir} --log-info=sqlalchemy.testing WORKERS={env:TOX_WORKERS:-n4 --max-worker-restart=5}