git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
correct for "autocommit" deprecation warning
author: Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 15 Nov 2020 21:58:50 +0000 (16:58 -0500)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 11 Dec 2020 18:26:05 +0000 (13:26 -0500)
Ensure no autocommit warnings occur internally or
within tests.

Also includes fixes for SQL Server full text tests
which apparently have not been working at all for a long
time, as they used long-removed APIs.  Fulltext had not been
running in CI for some years and is now installed.

Change-Id: Id806e1856c9da9f0a9eac88cebc7a94ecc95eb96

65 files changed:
lib/sqlalchemy/dialects/mysql/provision.py
lib/sqlalchemy/dialects/oracle/provision.py
lib/sqlalchemy/dialects/postgresql/provision.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/testing/suite/test_cte.py
lib/sqlalchemy/testing/suite/test_dialect.py
lib/sqlalchemy/testing/suite/test_results.py
lib/sqlalchemy/testing/suite/test_rowcount.py
lib/sqlalchemy/testing/suite/test_types.py
lib/sqlalchemy/testing/util.py
lib/sqlalchemy/testing/warnings.py
test/aaa_profiling/test_resultset.py
test/conftest.py
test/dialect/mssql/test_engine.py
test/dialect/mssql/test_query.py
test/dialect/mssql/test_reflection.py
test/dialect/mssql/test_types.py
test/dialect/mysql/test_dialect.py
test/dialect/mysql/test_on_duplicate.py
test/dialect/mysql/test_query.py
test/dialect/mysql/test_reflection.py
test/dialect/oracle/test_dialect.py
test/dialect/oracle/test_reflection.py
test/dialect/oracle/test_types.py
test/dialect/postgresql/test_dialect.py
test/dialect/postgresql/test_on_conflict.py
test/dialect/postgresql/test_query.py
test/dialect/postgresql/test_reflection.py
test/dialect/postgresql/test_types.py
test/dialect/test_mxodbc.py
test/dialect/test_sqlite.py
test/engine/test_ddlevents.py
test/engine/test_deprecations.py
test/engine/test_execute.py
test/engine/test_logging.py
test/engine/test_reconnect.py
test/engine/test_reflection.py
test/engine/test_transaction.py
test/ext/test_associationproxy.py
test/ext/test_horizontal_shard.py
test/orm/inheritance/test_selects.py
test/orm/test_bind.py
test/orm/test_compile.py
test/orm/test_eager_relations.py
test/orm/test_expire.py
test/orm/test_lazy_relations.py
test/orm/test_mapper.py
test/orm/test_naturalpks.py
test/orm/test_query.py
test/orm/test_session.py
test/orm/test_transaction.py
test/orm/test_unitofworkv2.py
test/sql/test_defaults.py
test/sql/test_delete.py
test/sql/test_deprecations.py
test/sql/test_query.py
test/sql/test_quote.py
test/sql/test_resultset.py
test/sql/test_returning.py
test/sql/test_sequences.py
test/sql/test_type_expressions.py
test/sql/test_types.py
test/sql/test_update.py
tox.ini

index c1d83bbb7652b6078d83b95b351725a2259c4b54..50b6e3c8508f0a677a2371c1dfec5b683b817259 100644 (file)
@@ -41,12 +41,13 @@ def generate_driver_url(url, driver, query_str):
 
 @create_db.for_db("mysql", "mariadb")
 def _mysql_create_db(cfg, eng, ident):
-    with eng.connect() as conn:
+    with eng.begin() as conn:
         try:
             _mysql_drop_db(cfg, conn, ident)
         except Exception:
             pass
 
+    with eng.begin() as conn:
         conn.exec_driver_sql(
             "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident
         )
@@ -66,7 +67,7 @@ def _mysql_configure_follower(config, ident):
 
 @drop_db.for_db("mysql", "mariadb")
 def _mysql_drop_db(cfg, eng, ident):
-    with eng.connect() as conn:
+    with eng.begin() as conn:
         conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident)
         conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident)
         conn.exec_driver_sql("DROP DATABASE %s" % ident)
index d19dfc9fe69fcd51398050c8bb9ed7139f740faa..aadc2c5a999d6021c89afd8f5cf4be06529c1d96 100644 (file)
@@ -17,7 +17,7 @@ def _oracle_create_db(cfg, eng, ident):
     # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
     # similar, so that the default tablespace is not "system"; reflection will
     # fail otherwise
-    with eng.connect() as conn:
+    with eng.begin() as conn:
         conn.exec_driver_sql("create user %s identified by xe" % ident)
         conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident)
         conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident)
@@ -45,7 +45,7 @@ def _ora_drop_ignore(conn, dbname):
 
 @drop_db.for_db("oracle")
 def _oracle_drop_db(cfg, eng, ident):
-    with eng.connect() as conn:
+    with eng.begin() as conn:
         # cx_Oracle seems to occasionally leak open connections when a large
         # suite it run, even if we confirm we have zero references to
         # connection objects.
@@ -65,7 +65,7 @@ def _oracle_update_db_opts(db_url, db_opts):
 def _reap_oracle_dbs(url, idents):
     log.info("db reaper connecting to %r", url)
     eng = create_engine(url)
-    with eng.connect() as conn:
+    with eng.begin() as conn:
 
         log.info("identifiers in file: %s", ", ".join(idents))
 
index 9433ec4585bef67dd6185cbf9c98cdf342c194dd..575316c61dee8a459e63274dca6c4fe5b892e33b 100644 (file)
@@ -13,7 +13,7 @@ from ...testing.provision import temp_table_keyword_args
 def _pg_create_db(cfg, eng, ident):
     template_db = cfg.options.postgresql_templatedb
 
-    with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
+    with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
         try:
             _pg_drop_db(cfg, conn, ident)
         except Exception:
@@ -51,15 +51,16 @@ def _pg_create_db(cfg, eng, ident):
 @drop_db.for_db("postgresql")
 def _pg_drop_db(cfg, eng, ident):
     with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
-        conn.execute(
-            text(
-                "select pg_terminate_backend(pid) from pg_stat_activity "
-                "where usename=current_user and pid != pg_backend_pid() "
-                "and datname=:dname"
-            ),
-            dname=ident,
-        )
-        conn.exec_driver_sql("DROP DATABASE %s" % ident)
+        with conn.begin():
+            conn.execute(
+                text(
+                    "select pg_terminate_backend(pid) from pg_stat_activity "
+                    "where usename=current_user and pid != pg_backend_pid() "
+                    "and datname=:dname"
+                ),
+                dname=ident,
+            )
+            conn.exec_driver_sql("DROP DATABASE %s" % ident)
 
 
 @temp_table_keyword_args.for_db("postgresql")
index 9a5518a961f185b9cd3e8d6ebf1ad6002547d21f..028af9fbb79e96d7b87d876014cf897ff2175194 100644 (file)
@@ -840,7 +840,15 @@ class Connection(Connectable):
     def _commit_impl(self, autocommit=False):
         assert not self.__branch_from
 
-        if autocommit:
+        # AUTOCOMMIT isolation-level is a dialect-specific concept, however
+        # if a connection has this set as the isolation level, we can skip
+        # the "autocommit" warning as the operation will do "autocommit"
+        # in any case
+        if (
+            autocommit
+            and self._execution_options.get("isolation_level", None)
+            != "AUTOCOMMIT"
+        ):
             util.warn_deprecated_20(
                 "The current statement is being autocommitted using "
                 "implicit autocommit, which will be removed in "
@@ -2687,9 +2695,11 @@ class Engine(Connectable, log.Identified):
         self.pool = self.pool.recreate()
         self.dispatch.engine_disposed(self)
 
-    def _execute_default(self, default):
+    def _execute_default(
+        self, default, multiparams=(), params=util.EMPTY_DICT
+    ):
         with self.connect() as conn:
-            return conn._execute_default(default, (), {})
+            return conn._execute_default(default, multiparams, params)
 
     @contextlib.contextmanager
     def _optional_conn_ctx_manager(self, connection=None):
index 4b19ff02a11d9bfaee1b5bb102ac12a2b9da63d0..b5e45c18d9108e3bdc99015b596516bfba471773 100644 (file)
@@ -2258,10 +2258,10 @@ class DefaultGenerator(SchemaItem):
         "or in the ORM by the :meth:`.Session.execute` method of "
         ":class:`.Session`.",
     )
-    def execute(self, bind=None, **kwargs):
+    def execute(self, bind=None):
         if bind is None:
             bind = _bind_or_error(self)
-        return bind.execute(self, **kwargs)
+        return bind._execute_default(self, (), util.EMPTY_DICT)
 
     def _execute_on_connection(
         self, connection, multiparams, params, execution_options
index 4addca009b011a83715ef448cb273709b47596a4..a94ee55dc03c92e464c968f0be66dbb9edca4fed 100644 (file)
@@ -1,4 +1,3 @@
-from .. import config
 from .. import fixtures
 from ..assertions import eq_
 from ..schema import Column
@@ -48,164 +47,158 @@ class CTETest(fixtures.TablesTest):
             ],
         )
 
-    def test_select_nonrecursive_round_trip(self):
+    def test_select_nonrecursive_round_trip(self, connection):
         some_table = self.tables.some_table
 
-        with config.db.connect() as conn:
-            cte = (
-                select(some_table)
-                .where(some_table.c.data.in_(["d2", "d3", "d4"]))
-                .cte("some_cte")
-            )
-            result = conn.execute(
-                select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
-            )
-            eq_(result.fetchall(), [("d4",)])
+        cte = (
+            select(some_table)
+            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+            .cte("some_cte")
+        )
+        result = connection.execute(
+            select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
+        )
+        eq_(result.fetchall(), [("d4",)])
 
-    def test_select_recursive_round_trip(self):
+    def test_select_recursive_round_trip(self, connection):
         some_table = self.tables.some_table
 
-        with config.db.connect() as conn:
-            cte = (
-                select(some_table)
-                .where(some_table.c.data.in_(["d2", "d3", "d4"]))
-                .cte("some_cte", recursive=True)
-            )
+        cte = (
+            select(some_table)
+            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+            .cte("some_cte", recursive=True)
+        )
 
-            cte_alias = cte.alias("c1")
-            st1 = some_table.alias()
-            # note that SQL Server requires this to be UNION ALL,
-            # can't be UNION
-            cte = cte.union_all(
-                select(st1).where(st1.c.id == cte_alias.c.parent_id)
-            )
-            result = conn.execute(
-                select(cte.c.data)
-                .where(cte.c.data != "d2")
-                .order_by(cte.c.data.desc())
-            )
-            eq_(
-                result.fetchall(),
-                [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
-            )
+        cte_alias = cte.alias("c1")
+        st1 = some_table.alias()
+        # note that SQL Server requires this to be UNION ALL,
+        # can't be UNION
+        cte = cte.union_all(
+            select(st1).where(st1.c.id == cte_alias.c.parent_id)
+        )
+        result = connection.execute(
+            select(cte.c.data)
+            .where(cte.c.data != "d2")
+            .order_by(cte.c.data.desc())
+        )
+        eq_(
+            result.fetchall(),
+            [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
+        )
 
-    def test_insert_from_select_round_trip(self):
+    def test_insert_from_select_round_trip(self, connection):
         some_table = self.tables.some_table
         some_other_table = self.tables.some_other_table
 
-        with config.db.connect() as conn:
-            cte = (
-                select(some_table)
-                .where(some_table.c.data.in_(["d2", "d3", "d4"]))
-                .cte("some_cte")
-            )
-            conn.execute(
-                some_other_table.insert().from_select(
-                    ["id", "data", "parent_id"], select(cte)
-                )
-            )
-            eq_(
-                conn.execute(
-                    select(some_other_table).order_by(some_other_table.c.id)
-                ).fetchall(),
-                [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
+        cte = (
+            select(some_table)
+            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+            .cte("some_cte")
+        )
+        connection.execute(
+            some_other_table.insert().from_select(
+                ["id", "data", "parent_id"], select(cte)
             )
+        )
+        eq_(
+            connection.execute(
+                select(some_other_table).order_by(some_other_table.c.id)
+            ).fetchall(),
+            [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
+        )
 
     @testing.requires.ctes_with_update_delete
     @testing.requires.update_from
-    def test_update_from_round_trip(self):
+    def test_update_from_round_trip(self, connection):
         some_table = self.tables.some_table
         some_other_table = self.tables.some_other_table
 
-        with config.db.connect() as conn:
-            conn.execute(
-                some_other_table.insert().from_select(
-                    ["id", "data", "parent_id"], select(some_table)
-                )
+        connection.execute(
+            some_other_table.insert().from_select(
+                ["id", "data", "parent_id"], select(some_table)
             )
+        )
 
-            cte = (
-                select(some_table)
-                .where(some_table.c.data.in_(["d2", "d3", "d4"]))
-                .cte("some_cte")
-            )
-            conn.execute(
-                some_other_table.update()
-                .values(parent_id=5)
-                .where(some_other_table.c.data == cte.c.data)
-            )
-            eq_(
-                conn.execute(
-                    select(some_other_table).order_by(some_other_table.c.id)
-                ).fetchall(),
-                [
-                    (1, "d1", None),
-                    (2, "d2", 5),
-                    (3, "d3", 5),
-                    (4, "d4", 5),
-                    (5, "d5", 3),
-                ],
-            )
+        cte = (
+            select(some_table)
+            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+            .cte("some_cte")
+        )
+        connection.execute(
+            some_other_table.update()
+            .values(parent_id=5)
+            .where(some_other_table.c.data == cte.c.data)
+        )
+        eq_(
+            connection.execute(
+                select(some_other_table).order_by(some_other_table.c.id)
+            ).fetchall(),
+            [
+                (1, "d1", None),
+                (2, "d2", 5),
+                (3, "d3", 5),
+                (4, "d4", 5),
+                (5, "d5", 3),
+            ],
+        )
 
     @testing.requires.ctes_with_update_delete
     @testing.requires.delete_from
-    def test_delete_from_round_trip(self):
+    def test_delete_from_round_trip(self, connection):
         some_table = self.tables.some_table
         some_other_table = self.tables.some_other_table
 
-        with config.db.connect() as conn:
-            conn.execute(
-                some_other_table.insert().from_select(
-                    ["id", "data", "parent_id"], select(some_table)
-                )
+        connection.execute(
+            some_other_table.insert().from_select(
+                ["id", "data", "parent_id"], select(some_table)
             )
+        )
 
-            cte = (
-                select(some_table)
-                .where(some_table.c.data.in_(["d2", "d3", "d4"]))
-                .cte("some_cte")
-            )
-            conn.execute(
-                some_other_table.delete().where(
-                    some_other_table.c.data == cte.c.data
-                )
-            )
-            eq_(
-                conn.execute(
-                    select(some_other_table).order_by(some_other_table.c.id)
-                ).fetchall(),
-                [(1, "d1", None), (5, "d5", 3)],
+        cte = (
+            select(some_table)
+            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+            .cte("some_cte")
+        )
+        connection.execute(
+            some_other_table.delete().where(
+                some_other_table.c.data == cte.c.data
             )
+        )
+        eq_(
+            connection.execute(
+                select(some_other_table).order_by(some_other_table.c.id)
+            ).fetchall(),
+            [(1, "d1", None), (5, "d5", 3)],
+        )
 
     @testing.requires.ctes_with_update_delete
-    def test_delete_scalar_subq_round_trip(self):
+    def test_delete_scalar_subq_round_trip(self, connection):
 
         some_table = self.tables.some_table
         some_other_table = self.tables.some_other_table
 
-        with config.db.connect() as conn:
-            conn.execute(
-                some_other_table.insert().from_select(
-                    ["id", "data", "parent_id"], select(some_table)
-                )
+        connection.execute(
+            some_other_table.insert().from_select(
+                ["id", "data", "parent_id"], select(some_table)
             )
+        )
 
-            cte = (
-                select(some_table)
-                .where(some_table.c.data.in_(["d2", "d3", "d4"]))
-                .cte("some_cte")
-            )
-            conn.execute(
-                some_other_table.delete().where(
-                    some_other_table.c.data
-                    == select(cte.c.data)
-                    .where(cte.c.id == some_other_table.c.id)
-                    .scalar_subquery()
-                )
-            )
-            eq_(
-                conn.execute(
-                    select(some_other_table).order_by(some_other_table.c.id)
-                ).fetchall(),
-                [(1, "d1", None), (5, "d5", 3)],
+        cte = (
+            select(some_table)
+            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+            .cte("some_cte")
+        )
+        connection.execute(
+            some_other_table.delete().where(
+                some_other_table.c.data
+                == select(cte.c.data)
+                .where(cte.c.id == some_other_table.c.id)
+                .scalar_subquery()
             )
+        )
+        eq_(
+            connection.execute(
+                select(some_other_table).order_by(some_other_table.c.id)
+            ).fetchall(),
+            [(1, "d1", None), (5, "d5", 3)],
+        )
index 7f697b915d02ab620082114514f36300cee07f02..b0df1218dd6c1f570f71cfb845b82971efe93b3e 100644 (file)
@@ -123,7 +123,7 @@ class IsolationLevelTest(fixtures.TestBase):
             eq_(conn.get_isolation_level(), existing)
 
 
-class AutocommitTest(fixtures.TablesTest):
+class AutocommitIsolationTest(fixtures.TablesTest):
 
     run_deletes = "each"
 
@@ -153,7 +153,8 @@ class AutocommitTest(fixtures.TablesTest):
             1 if autocommit else None,
         )
 
-        conn.execute(self.tables.some_table.delete())
+        with conn.begin():
+            conn.execute(self.tables.some_table.delete())
 
     def test_autocommit_on(self):
         conn = config.db.connect()
@@ -170,7 +171,7 @@ class AutocommitTest(fixtures.TablesTest):
 
     def test_turn_autocommit_off_via_default_iso_level(self):
         conn = config.db.connect()
-        conn.execution_options(isolation_level="AUTOCOMMIT")
+        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
         self._test_conn_autocommits(conn, True)
 
         conn.execution_options(
index 9484d41d09c4639cbca7aa4ea86cd483dcee4b15..0298738663305a16ec753ab08d2e0013f5035594 100644 (file)
@@ -355,7 +355,7 @@ class ServerSideCursorsTest(
             Column("data", String(50)),
         )
 
-        with engine.connect() as connection:
+        with engine.begin() as connection:
             test_table.create(connection, checkfirst=True)
             connection.execute(test_table.insert(), dict(data="data1"))
             connection.execute(test_table.insert(), dict(data="data2"))
@@ -396,7 +396,7 @@ class ServerSideCursorsTest(
             Column("data", String(50)),
         )
 
-        with engine.connect() as connection:
+        with engine.begin() as connection:
             test_table.create(connection, checkfirst=True)
             connection.execute(
                 test_table.insert(),
index 06945ff2a7cc7bca144b60a59a9ae3a6302a6778..f3f902abd26133b97e34cd9e98f9fdf0f09e1016 100644 (file)
@@ -58,12 +58,14 @@ class RowCountTest(fixtures.TablesTest):
 
         assert len(r) == len(self.data)
 
-    def test_update_rowcount1(self):
+    def test_update_rowcount1(self, connection):
         employees_table = self.tables.employees
 
         # WHERE matches 3, 3 rows changed
         department = employees_table.c.department
-        r = employees_table.update(department == "C").execute(department="Z")
+        r = connection.execute(
+            employees_table.update(department == "C"), {"department": "Z"}
+        )
         assert r.rowcount == 3
 
     def test_update_rowcount2(self, connection):
index da01aa484bd4b6d26951fbe3562021a41d453069..21d2e8942d1d972d6d23f2448db3f2243b7f83d9 100644 (file)
@@ -340,7 +340,7 @@ class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
         # passing NULL for an expression that needs to be interpreted as
         # a certain type, does the DBAPI have the info it needs to do this.
         date_table = self.tables.date_table
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             result = conn.execute(
                 date_table.insert(), {"date_data": self.data}
             )
@@ -702,7 +702,7 @@ class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
         # testing "WHERE <column>" renders a compatible expression
         boolean_table = self.tables.boolean_table
 
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 boolean_table.insert(),
                 [
@@ -817,7 +817,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_index_typed_access(self, datatype, value):
         data_table = self.tables.data_table
         data_element = {"key1": value}
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 data_table.insert(),
                 {
@@ -841,7 +841,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_index_typed_comparison(self, datatype, value):
         data_table = self.tables.data_table
         data_element = {"key1": value}
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 data_table.insert(),
                 {
@@ -864,7 +864,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_path_typed_comparison(self, datatype, value):
         data_table = self.tables.data_table
         data_element = {"key1": {"subkey1": value}}
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 data_table.insert(),
                 {
@@ -900,7 +900,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_single_element_round_trip(self, element):
         data_table = self.tables.data_table
         data_element = element
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 data_table.insert(),
                 {
@@ -928,7 +928,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
         # support sqlite :memory: database...
         data_table.create(engine, checkfirst=True)
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             conn.execute(
                 data_table.insert(), {"name": "row1", "data": data_element}
             )
@@ -978,7 +978,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
     def test_round_trip_none_as_json_null(self):
         col = self.tables.data_table.c["data"]
 
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 self.tables.data_table.insert(), {"name": "r1", "data": None}
             )
@@ -996,7 +996,7 @@ class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
 
     def test_unicode_round_trip(self):
         # note we include Unicode supplementary characters as well
-        with config.db.connect() as conn:
+        with config.db.begin() as conn:
             conn.execute(
                 self.tables.data_table.insert(),
                 {
index c52dc4a19b524fe11c5cd408211e9f14e005a1fe..c6626b9e087f9f17c6b9a801c9e461409f43dd85 100644 (file)
@@ -370,7 +370,7 @@ def drop_all_tables(engine, inspector, schema=None, include_names=None):
     if include_names is not None:
         include_names = set(include_names)
 
-    with engine.connect() as conn:
+    with engine.begin() as conn:
         for tname, fkcs in reversed(
             inspector.get_sorted_table_and_fkc_names(schema=schema)
         ):
index b230bad6f09dd4d6861e20bd176ef44e0be45c63..cd919cc0b0e8be58e896c7024e14f2b88b2198d6 100644 (file)
@@ -51,13 +51,11 @@ def setup_filters():
         # Core execution
         #
         r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method",
-        r"The current statement is being autocommitted using implicit "
-        "autocommit,",
         r"The connection.execute\(\) method in SQLAlchemy 2.0 will accept "
         "parameters as a single dictionary or a single sequence of "
         "dictionaries only.",
         r"The Connection.connect\(\) method is considered legacy",
-        r".*DefaultGenerator.execute\(\)",
+        #        r".*DefaultGenerator.execute\(\)",
         #
         # bound metadaa
         #
index 7188c412505831366ce406df5a910f71e7c2ec42..d36a0c9e1b17b4594cb87dd3d2a9f9eca5853cb8 100644 (file)
@@ -48,25 +48,28 @@ class ResultSetTest(fixtures.TestBase, AssertsExecutionResults):
         )
 
     def setup(self):
-        metadata.create_all()
-        t.insert().execute(
-            [
-                dict(
-                    ("field%d" % fnum, u("value%d" % fnum))
-                    for fnum in range(NUM_FIELDS)
-                )
-                for r_num in range(NUM_RECORDS)
-            ]
-        )
-        t2.insert().execute(
-            [
-                dict(
-                    ("field%d" % fnum, u("value%d" % fnum))
-                    for fnum in range(NUM_FIELDS)
-                )
-                for r_num in range(NUM_RECORDS)
-            ]
-        )
+        with testing.db.begin() as conn:
+            metadata.create_all(conn)
+            conn.execute(
+                t.insert(),
+                [
+                    dict(
+                        ("field%d" % fnum, u("value%d" % fnum))
+                        for fnum in range(NUM_FIELDS)
+                    )
+                    for r_num in range(NUM_RECORDS)
+                ],
+            )
+            conn.execute(
+                t2.insert(),
+                [
+                    dict(
+                        ("field%d" % fnum, u("value%d" % fnum))
+                        for fnum in range(NUM_FIELDS)
+                    )
+                    for r_num in range(NUM_RECORDS)
+                ],
+            )
 
         # warm up type caches
         with testing.db.connect() as conn:
index 63f3989ebc1a41ba76ecba9de54ebba9f48a912b..0db4486a92fbd5bed5fc9063bf41a4c3f00b92c5 100755 (executable)
@@ -12,6 +12,8 @@ import sys
 import pytest
 
 
+os.environ["SQLALCHEMY_WARN_20"] = "true"
+
 collect_ignore_glob = []
 
 # minimum version for a py3k only test is at
index 44445595893d145eebfa4b1f0a73e573619e7bb3..668df6ecbc83967075548e2c097f5a56cde1e351 100644 (file)
@@ -382,7 +382,7 @@ class FastExecutemanyTest(fixtures.TestBase):
             if executemany:
                 assert cursor.fast_executemany
 
-        with eng.connect() as conn:
+        with eng.begin() as conn:
             conn.execute(
                 t.insert(),
                 [{"id": i, "data": "data_%d" % i} for i in range(100)],
index d9dc033e1668349e63add06ac7575e2779d87a22..ea0bfa4d2708e65e11865ae31383c199e721340c 100644 (file)
@@ -9,7 +9,6 @@ from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import Integer
 from sqlalchemy import literal
-from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import select
@@ -26,22 +25,15 @@ from sqlalchemy.testing.assertsql import CursorSQL
 from sqlalchemy.testing.assertsql import DialectSQL
 from sqlalchemy.util import ue
 
-metadata = None
-cattable = None
-matchtable = None
 
-
-class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL):
+class IdentityInsertTest(fixtures.TablesTest, AssertsCompiledSQL):
     __only_on__ = "mssql"
     __dialect__ = mssql.MSDialect()
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global metadata, cattable
-        metadata = MetaData(testing.db)
-
-        cattable = Table(
+    def define_tables(cls, metadata):
+        Table(
             "cattable",
             metadata,
             Column("id", Integer),
@@ -49,82 +41,82 @@ class IdentityInsertTest(fixtures.TestBase, AssertsCompiledSQL):
             PrimaryKeyConstraint("id", name="PK_cattable"),
         )
 
-    def setup(self):
-        metadata.create_all()
-
-    def teardown(self):
-        metadata.drop_all()
-
     def test_compiled(self):
+        cattable = self.tables.cattable
         self.assert_compile(
             cattable.insert().values(id=9, description="Python"),
             "INSERT INTO cattable (id, description) "
             "VALUES (:id, :description)",
         )
 
-    def test_execute(self):
-        with testing.db.connect() as conn:
-            conn.execute(cattable.insert().values(id=9, description="Python"))
-
-            cats = conn.execute(cattable.select().order_by(cattable.c.id))
-            eq_([(9, "Python")], list(cats))
+    def test_execute(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(cattable.insert().values(id=9, description="Python"))
 
-            result = conn.execute(cattable.insert().values(description="PHP"))
-            eq_(result.inserted_primary_key, (10,))
-            lastcat = conn.execute(
-                cattable.select().order_by(desc(cattable.c.id))
-            )
-            eq_((10, "PHP"), lastcat.first())
-
-    def test_executemany(self):
-        with testing.db.connect() as conn:
-            conn.execute(
-                cattable.insert(),
-                [
-                    {"id": 89, "description": "Python"},
-                    {"id": 8, "description": "Ruby"},
-                    {"id": 3, "description": "Perl"},
-                    {"id": 1, "description": "Java"},
-                ],
-            )
-            cats = conn.execute(cattable.select().order_by(cattable.c.id))
-            eq_(
-                [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")],
-                list(cats),
-            )
-            conn.execute(
-                cattable.insert(),
-                [{"description": "PHP"}, {"description": "Smalltalk"}],
-            )
-            lastcats = conn.execute(
-                cattable.select().order_by(desc(cattable.c.id)).limit(2)
-            )
-            eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats))
+        cats = conn.execute(cattable.select().order_by(cattable.c.id))
+        eq_([(9, "Python")], list(cats))
 
-    def test_insert_plain_param(self):
-        with testing.db.connect() as conn:
-            conn.execute(cattable.insert(), id=5)
-            eq_(conn.scalar(select(cattable.c.id)), 5)
+        result = conn.execute(cattable.insert().values(description="PHP"))
+        eq_(result.inserted_primary_key, (10,))
+        lastcat = conn.execute(cattable.select().order_by(desc(cattable.c.id)))
+        eq_((10, "PHP"), lastcat.first())
 
-    def test_insert_values_key_plain(self):
-        with testing.db.connect() as conn:
-            conn.execute(cattable.insert().values(id=5))
-            eq_(conn.scalar(select(cattable.c.id)), 5)
-
-    def test_insert_values_key_expression(self):
-        with testing.db.connect() as conn:
-            conn.execute(cattable.insert().values(id=literal(5)))
-            eq_(conn.scalar(select(cattable.c.id)), 5)
-
-    def test_insert_values_col_plain(self):
-        with testing.db.connect() as conn:
-            conn.execute(cattable.insert().values({cattable.c.id: 5}))
-            eq_(conn.scalar(select(cattable.c.id)), 5)
-
-    def test_insert_values_col_expression(self):
-        with testing.db.connect() as conn:
-            conn.execute(cattable.insert().values({cattable.c.id: literal(5)}))
-            eq_(conn.scalar(select(cattable.c.id)), 5)
+    def test_executemany(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(
+            cattable.insert(),
+            [
+                {"id": 89, "description": "Python"},
+                {"id": 8, "description": "Ruby"},
+                {"id": 3, "description": "Perl"},
+                {"id": 1, "description": "Java"},
+            ],
+        )
+        cats = conn.execute(cattable.select().order_by(cattable.c.id))
+        eq_(
+            [(1, "Java"), (3, "Perl"), (8, "Ruby"), (89, "Python")],
+            list(cats),
+        )
+        conn.execute(
+            cattable.insert(),
+            [{"description": "PHP"}, {"description": "Smalltalk"}],
+        )
+        lastcats = conn.execute(
+            cattable.select().order_by(desc(cattable.c.id)).limit(2)
+        )
+        eq_([(91, "Smalltalk"), (90, "PHP")], list(lastcats))
+
+    def test_insert_plain_param(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(cattable.insert(), id=5)
+        eq_(conn.scalar(select(cattable.c.id)), 5)
+
+    def test_insert_values_key_plain(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(cattable.insert().values(id=5))
+        eq_(conn.scalar(select(cattable.c.id)), 5)
+
+    def test_insert_values_key_expression(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(cattable.insert().values(id=literal(5)))
+        eq_(conn.scalar(select(cattable.c.id)), 5)
+
+    def test_insert_values_col_plain(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(cattable.insert().values({cattable.c.id: 5}))
+        eq_(conn.scalar(select(cattable.c.id)), 5)
+
+    def test_insert_values_col_expression(self, connection):
+        conn = connection
+        cattable = self.tables.cattable
+        conn.execute(cattable.insert().values({cattable.c.id: literal(5)}))
+        eq_(conn.scalar(select(cattable.c.id)), 5)
 
 
 class QueryUnicodeTest(fixtures.TestBase):
@@ -391,37 +383,35 @@ def full_text_search_missing():
     """Test if full text search is not implemented and return False if
     it is and True otherwise."""
 
-    try:
-        connection = testing.db.connect()
-        try:
-            connection.exec_driver_sql(
-                "CREATE FULLTEXT CATALOG Catalog AS " "DEFAULT"
-            )
-            return False
-        except Exception:
-            return True
-    finally:
-        connection.close()
+    if not testing.against("mssql"):
+        return True
+
+    with testing.db.connect() as conn:
+        result = conn.exec_driver_sql(
+            "SELECT cast(SERVERPROPERTY('IsFullTextInstalled') as integer)"
+        )
+        return result.scalar() == 0
 
 
-class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
+class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
 
     __only_on__ = "mssql"
     __skip_if__ = (full_text_search_missing,)
     __backend__ = True
 
+    run_setup_tables = "once"
+    run_inserts = run_deletes = "once"
+
     @classmethod
-    def setup_class(cls):
-        global metadata, cattable, matchtable
-        metadata = MetaData(testing.db)
-        cattable = Table(
+    def define_tables(cls, metadata):
+        Table(
             "cattable",
             metadata,
             Column("id", Integer),
             Column("description", String(50)),
             PrimaryKeyConstraint("id", name="PK_cattable"),
         )
-        matchtable = Table(
+        Table(
             "matchtable",
             metadata,
             Column("id", Integer),
@@ -429,24 +419,65 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
             Column("category_id", Integer, ForeignKey("cattable.id")),
             PrimaryKeyConstraint("id", name="PK_matchtable"),
         )
-        DDL(
-            """CREATE FULLTEXT INDEX
+
+        event.listen(
+            metadata,
+            "before_create",
+            DDL("CREATE FULLTEXT CATALOG Catalog AS DEFAULT"),
+        )
+        event.listen(
+            metadata,
+            "after_create",
+            DDL(
+                """CREATE FULLTEXT INDEX
                        ON cattable (description)
                        KEY INDEX PK_cattable"""
-        ).execute_at("after-create", matchtable)
-        DDL(
-            """CREATE FULLTEXT INDEX
+            ),
+        )
+        event.listen(
+            metadata,
+            "after_create",
+            DDL(
+                """CREATE FULLTEXT INDEX
                        ON matchtable (title)
                        KEY INDEX PK_matchtable"""
-        ).execute_at("after-create", matchtable)
-        metadata.create_all()
-        cattable.insert().execute(
+            ),
+        )
+
+        event.listen(
+            metadata,
+            "after_drop",
+            DDL("DROP FULLTEXT CATALOG Catalog"),
+        )
+
+    @classmethod
+    def setup_bind(cls):
+        return testing.db.execution_options(isolation_level="AUTOCOMMIT")
+
+    @classmethod
+    def setup_class(cls):
+        with testing.db.connect().execution_options(
+            isolation_level="AUTOCOMMIT"
+        ) as conn:
+            try:
+                conn.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
+            except:
+                pass
+        super(MatchTest, cls).setup_class()
+
+    @classmethod
+    def insert_data(cls, connection):
+        cattable, matchtable = cls.tables("cattable", "matchtable")
+
+        connection.execute(
+            cattable.insert(),
             [
                 {"id": 1, "description": "Python"},
                 {"id": 2, "description": "Ruby"},
-            ]
+            ],
         )
-        matchtable.insert().execute(
+        connection.execute(
+            matchtable.insert(),
             [
                 {
                     "id": 1,
@@ -461,62 +492,53 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
                 },
                 {"id": 4, "title": "Guide to Django", "category_id": 1},
                 {"id": 5, "title": "Python in a Nutshell", "category_id": 1},
-            ]
+            ],
         )
-        DDL("WAITFOR DELAY '00:00:05'").execute(bind=engines.testing_engine())
-
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
-        connection = testing.db.connect()
-        connection.exec_driver_sql("DROP FULLTEXT CATALOG Catalog")
-        connection.close()
+        # apparently this is needed!   index must run asynchronously
+        connection.execute(DDL("WAITFOR DELAY '00:00:05'"))
 
     def test_expression(self):
+        matchtable = self.tables.matchtable
         self.assert_compile(
             matchtable.c.title.match("somstr"),
             "CONTAINS (matchtable.title, ?)",
         )
 
-    def test_simple_match(self):
-        results = (
+    def test_simple_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
             matchtable.select()
             .where(matchtable.c.title.match("python"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([2, 5], [r.id for r in results])
 
-    def test_simple_match_with_apostrophe(self):
-        results = (
-            matchtable.select()
-            .where(matchtable.c.title.match("Matz's"))
-            .execute()
-            .fetchall()
-        )
+    def test_simple_match_with_apostrophe(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
+            matchtable.select().where(matchtable.c.title.match("Matz's"))
+        ).fetchall()
         eq_([3], [r.id for r in results])
 
-    def test_simple_prefix_match(self):
-        results = (
-            matchtable.select()
-            .where(matchtable.c.title.match('"nut*"'))
-            .execute()
-            .fetchall()
-        )
+    def test_simple_prefix_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
+            matchtable.select().where(matchtable.c.title.match('"nut*"'))
+        ).fetchall()
         eq_([5], [r.id for r in results])
 
-    def test_simple_inflectional_match(self):
-        results = (
-            matchtable.select()
-            .where(matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")'))
-            .execute()
-            .fetchall()
-        )
+    def test_simple_inflectional_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
+            matchtable.select().where(
+                matchtable.c.title.match('FORMSOF(INFLECTIONAL, "dives")')
+            )
+        ).fetchall()
         eq_([2], [r.id for r in results])
 
-    def test_or_match(self):
-        results1 = (
+    def test_or_match(self, connection):
+        matchtable = self.tables.matchtable
+        results1 = connection.execute(
             matchtable.select()
             .where(
                 or_(
@@ -525,31 +547,25 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
                 )
             )
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([3, 5], [r.id for r in results1])
-        results2 = (
+        results2 = connection.execute(
             matchtable.select()
             .where(matchtable.c.title.match("nutshell OR ruby"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([3, 5], [r.id for r in results2])
 
-    def test_and_match(self):
-        results1 = (
-            matchtable.select()
-            .where(
+    def test_and_match(self, connection):
+        matchtable = self.tables.matchtable
+        results1 = connection.execute(
+            matchtable.select().where(
                 and_(
                     matchtable.c.title.match("python"),
                     matchtable.c.title.match("nutshell"),
                 )
             )
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([5], [r.id for r in results1])
         results2 = (
             matchtable.select()
@@ -559,8 +575,10 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
         )
         eq_([5], [r.id for r in results2])
 
-    def test_match_across_joins(self):
-        results = (
+    def test_match_across_joins(self, connection):
+        matchtable = self.tables.matchtable
+        cattable = self.tables.cattable
+        results = connection.execute(
             matchtable.select()
             .where(
                 and_(
@@ -572,7 +590,5 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
                 )
             )
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([1, 3, 5], [r.id for r in results])
index 6009bfb6cb88f82f3136827865e034331b34e436..86c97316ad2a9143d2617dd2b3ae8d7eb197ab9c 100644 (file)
@@ -741,14 +741,9 @@ class IdentityReflectionTest(fixtures.TablesTest):
 
     @testing.requires.views
     def test_reflect_views(self, connection):
-        try:
-            with testing.db.connect() as conn:
-                conn.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1")
-            insp = inspect(testing.db)
-            for col in insp.get_columns("view1"):
-                is_true("dialect_options" not in col)
-                is_true("identity" in col)
-                eq_(col["identity"], {})
-        finally:
-            with testing.db.connect() as conn:
-                conn.exec_driver_sql("DROP VIEW view1")
+        connection.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1")
+        insp = inspect(connection)
+        for col in insp.get_columns("view1"):
+            is_true("dialect_options" not in col)
+            is_true("identity" in col)
+            eq_(col["identity"], {})
index 11a2a25b3fe04a8cd7e1cf2b190c4d743cc7b0ca..a4a3bedda3c26f67a142190a09b73e02438ec2f8 100644 (file)
@@ -221,7 +221,7 @@ class RowVersionTest(fixtures.TablesTest):
             Column("rv", cls(convert_int=convert_int)),
         )
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(t.insert().values(data="foo"))
             last_ts_1 = conn.exec_driver_sql("SELECT @@DBTS").scalar()
 
@@ -545,7 +545,7 @@ class TypeRoundTripTest(
     __backend__ = True
 
     @testing.provide_metadata
-    def test_decimal_notation(self):
+    def test_decimal_notation(self, connection):
         metadata = self.metadata
         numeric_table = Table(
             "numeric_table",
@@ -560,7 +560,7 @@ class TypeRoundTripTest(
                 "numericcol", Numeric(precision=38, scale=20, asdecimal=True)
             ),
         )
-        metadata.create_all()
+        metadata.create_all(connection)
         test_items = [
             decimal.Decimal(d)
             for d in (
@@ -623,21 +623,20 @@ class TypeRoundTripTest(
             )
         ]
 
-        with testing.db.connect() as conn:
-            for value in test_items:
-                result = conn.execute(
-                    numeric_table.insert(), dict(numericcol=value)
-                )
-                primary_key = result.inserted_primary_key
-                returned = conn.scalar(
-                    select(numeric_table.c.numericcol).where(
-                        numeric_table.c.id == primary_key[0]
-                    )
+        for value in test_items:
+            result = connection.execute(
+                numeric_table.insert(), dict(numericcol=value)
+            )
+            primary_key = result.inserted_primary_key
+            returned = connection.scalar(
+                select(numeric_table.c.numericcol).where(
+                    numeric_table.c.id == primary_key[0]
                 )
-                eq_(value, returned)
+            )
+            eq_(value, returned)
 
     @testing.provide_metadata
-    def test_float(self):
+    def test_float(self, connection):
         metadata = self.metadata
 
         float_table = Table(
@@ -652,41 +651,47 @@ class TypeRoundTripTest(
             Column("floatcol", Float()),
         )
 
-        metadata.create_all()
-        try:
-            test_items = [
-                float(d)
-                for d in (
-                    "1500000.00000000000000000000",
-                    "-1500000.00000000000000000000",
-                    "1500000",
-                    "0.0000000000000000002",
-                    "0.2",
-                    "-0.0000000000000000002",
-                    "156666.458923543",
-                    "-156666.458923543",
-                    "1",
-                    "-1",
-                    "1234",
-                    "2E-12",
-                    "4E8",
-                    "3E-6",
-                    "3E-7",
-                    "4.1",
-                    "1E-1",
-                    "1E-2",
-                    "1E-3",
-                    "1E-4",
-                    "1E-5",
-                    "1E-6",
-                    "1E-7",
-                    "1E-8",
+        metadata.create_all(connection)
+        test_items = [
+            float(d)
+            for d in (
+                "1500000.00000000000000000000",
+                "-1500000.00000000000000000000",
+                "1500000",
+                "0.0000000000000000002",
+                "0.2",
+                "-0.0000000000000000002",
+                "156666.458923543",
+                "-156666.458923543",
+                "1",
+                "-1",
+                "1234",
+                "2E-12",
+                "4E8",
+                "3E-6",
+                "3E-7",
+                "4.1",
+                "1E-1",
+                "1E-2",
+                "1E-3",
+                "1E-4",
+                "1E-5",
+                "1E-6",
+                "1E-7",
+                "1E-8",
+            )
+        ]
+        for value in test_items:
+            result = connection.execute(
+                float_table.insert(), dict(floatcol=value)
+            )
+            primary_key = result.inserted_primary_key
+            returned = connection.scalar(
+                select(float_table.c.floatcol).where(
+                    float_table.c.id == primary_key[0]
                 )
-            ]
-            for value in test_items:
-                float_table.insert().execute(floatcol=value)
-        except Exception as e:
-            raise e
+            )
+            eq_(value, returned)
 
     # todo this should suppress warnings, but it does not
     @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*")
@@ -770,18 +775,17 @@ class TypeRoundTripTest(
         d2 = datetime.datetime(2007, 10, 30, 11, 2, 32)
         return t, (d1, t1, d2)
 
-    def test_date_roundtrips(self, date_fixture):
+    def test_date_roundtrips(self, date_fixture, connection):
         t, (d1, t1, d2) = date_fixture
-        with testing.db.begin() as conn:
-            conn.execute(
-                t.insert(), adate=d1, adatetime=d2, atime1=t1, atime2=d2
-            )
+        connection.execute(
+            t.insert(), adate=d1, adatetime=d2, atime1=t1, atime2=d2
+        )
 
-            row = conn.execute(t.select()).first()
-            eq_(
-                (row.adate, row.adatetime, row.atime1, row.atime2),
-                (d1, d2, t1, d2.time()),
-            )
+        row = connection.execute(t.select()).first()
+        eq_(
+            (row.adate, row.adatetime, row.atime1, row.atime2),
+            (d1, d2, t1, d2.time()),
+        )
 
     @testing.metadata_fixture()
     def datetimeoffset_fixture(self, metadata):
@@ -870,45 +874,45 @@ class TypeRoundTripTest(
         dto_param_value,
         expected_offset_hours,
         should_fail,
+        connection,
     ):
         t = datetimeoffset_fixture
         dto_param_value = dto_param_value()
 
-        with testing.db.begin() as conn:
-            if should_fail:
-                assert_raises(
-                    sa.exc.DBAPIError,
-                    conn.execute,
-                    t.insert(),
-                    adatetimeoffset=dto_param_value,
-                )
-                return
-
-            conn.execute(
+        if should_fail:
+            assert_raises(
+                sa.exc.DBAPIError,
+                connection.execute,
                 t.insert(),
                 adatetimeoffset=dto_param_value,
             )
+            return
 
-            row = conn.execute(t.select()).first()
+        connection.execute(
+            t.insert(),
+            adatetimeoffset=dto_param_value,
+        )
 
-            if dto_param_value is None:
-                is_(row.adatetimeoffset, None)
-            else:
-                eq_(
-                    row.adatetimeoffset,
-                    datetime.datetime(
-                        2007,
-                        10,
-                        30,
-                        11,
-                        2,
-                        32,
-                        123456,
-                        util.timezone(
-                            datetime.timedelta(hours=expected_offset_hours)
-                        ),
+        row = connection.execute(t.select()).first()
+
+        if dto_param_value is None:
+            is_(row.adatetimeoffset, None)
+        else:
+            eq_(
+                row.adatetimeoffset,
+                datetime.datetime(
+                    2007,
+                    10,
+                    30,
+                    11,
+                    2,
+                    32,
+                    123456,
+                    util.timezone(
+                        datetime.timedelta(hours=expected_offset_hours)
                     ),
-                )
+                ),
+            )
 
     @emits_warning_on("mssql+mxodbc", r".*does not have any indexes.*")
     @testing.provide_metadata
@@ -1173,7 +1177,7 @@ class BinaryTest(fixtures.TestBase):
         if expected is None:
             expected = data
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             conn.execute(binary_table.insert(), data=data)
 
             eq_(conn.scalar(select(binary_table.c.data)), expected)
index abd3a491ff1a8a1d533fa8e091cf0e3a6de954b3..3c569bf058e705cb21329f4eb41e4b92977021e9 100644 (file)
@@ -20,7 +20,7 @@ from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
-from ...engine import test_execute
+from ...engine import test_deprecations
 
 
 class BackendDialectTest(fixtures.TestBase):
@@ -382,56 +382,56 @@ class RemoveUTCTimestampTest(fixtures.TablesTest):
             Column("udata", DateTime, onupdate=func.utc_timestamp()),
         )
 
-    def test_insert_executemany(self):
-        with testing.db.connect() as conn:
-            conn.execute(
-                self.tables.t.insert().values(data=func.utc_timestamp()),
-                [{"x": 5}, {"x": 6}, {"x": 7}],
-            )
+    def test_insert_executemany(self, connection):
+        conn = connection
+        conn.execute(
+            self.tables.t.insert().values(data=func.utc_timestamp()),
+            [{"x": 5}, {"x": 6}, {"x": 7}],
+        )
 
-    def test_update_executemany(self):
-        with testing.db.connect() as conn:
-            timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
-            conn.execute(
-                self.tables.t.insert(),
-                [
-                    {"x": 5, "data": timestamp},
-                    {"x": 6, "data": timestamp},
-                    {"x": 7, "data": timestamp},
-                ],
-            )
+    def test_update_executemany(self, connection):
+        conn = connection
+        timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
+        conn.execute(
+            self.tables.t.insert(),
+            [
+                {"x": 5, "data": timestamp},
+                {"x": 6, "data": timestamp},
+                {"x": 7, "data": timestamp},
+            ],
+        )
 
-            conn.execute(
-                self.tables.t.update()
-                .values(data=func.utc_timestamp())
-                .where(self.tables.t.c.x == bindparam("xval")),
-                [{"xval": 5}, {"xval": 6}, {"xval": 7}],
-            )
+        conn.execute(
+            self.tables.t.update()
+            .values(data=func.utc_timestamp())
+            .where(self.tables.t.c.x == bindparam("xval")),
+            [{"xval": 5}, {"xval": 6}, {"xval": 7}],
+        )
 
-    def test_insert_executemany_w_default(self):
-        with testing.db.connect() as conn:
-            conn.execute(
-                self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}]
-            )
+    def test_insert_executemany_w_default(self, connection):
+        conn = connection
+        conn.execute(
+            self.tables.t_default.insert(), [{"x": 5}, {"x": 6}, {"x": 7}]
+        )
 
-    def test_update_executemany_w_default(self):
-        with testing.db.connect() as conn:
-            timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
-            conn.execute(
-                self.tables.t_default.insert(),
-                [
-                    {"x": 5, "idata": timestamp},
-                    {"x": 6, "idata": timestamp},
-                    {"x": 7, "idata": timestamp},
-                ],
-            )
+    def test_update_executemany_w_default(self, connection):
+        conn = connection
+        timestamp = datetime.datetime(2015, 4, 17, 18, 5, 2)
+        conn.execute(
+            self.tables.t_default.insert(),
+            [
+                {"x": 5, "idata": timestamp},
+                {"x": 6, "idata": timestamp},
+                {"x": 7, "idata": timestamp},
+            ],
+        )
 
-            conn.execute(
-                self.tables.t_default.update()
-                .values(idata=func.utc_timestamp())
-                .where(self.tables.t_default.c.x == bindparam("xval")),
-                [{"xval": 5}, {"xval": 6}, {"xval": 7}],
-            )
+        conn.execute(
+            self.tables.t_default.update()
+            .values(idata=func.utc_timestamp())
+            .where(self.tables.t_default.c.x == bindparam("xval")),
+            [{"xval": 5}, {"xval": 6}, {"xval": 7}],
+        )
 
 
 class SQLModeDetectionTest(fixtures.TestBase):
@@ -505,7 +505,7 @@ class ExecutionTest(fixtures.TestBase):
 
 
 class AutocommitTextTest(
-    test_execute.AutocommitKeywordFixture, fixtures.TestBase
+    test_deprecations.AutocommitKeywordFixture, fixtures.TestBase
 ):
     __only_on__ = "mysql", "mariadb"
 
index ed88121a553f3293054ef049dee39d8b861c81cf..dc86aaeb05d486e0d6ca842872f714f6ca8062cb 100644 (file)
@@ -5,7 +5,6 @@ from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import Table
-from sqlalchemy import testing
 from sqlalchemy.dialects.mysql import insert
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
@@ -47,155 +46,145 @@ class OnDuplicateTest(fixtures.TablesTest):
             {"id": 2, "bar": "baz"},
         )
 
-    def test_on_duplicate_key_update_multirow(self):
+    def test_on_duplicate_key_update_multirow(self, connection):
         foos = self.tables.foos
-        with testing.db.connect() as conn:
-            conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
-            stmt = insert(foos).values(
-                [dict(id=1, bar="ab"), dict(id=2, bar="b")]
-            )
-            stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
-
-            result = conn.execute(stmt)
-
-            # multirow, so its ambiguous.  this is a behavioral change
-            # in 1.4
-            eq_(result.inserted_primary_key, (None,))
-            eq_(
-                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
-                [(1, "ab", "bz", False)],
-            )
+        conn = connection
+        conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+        stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")])
+        stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+        result = conn.execute(stmt)
+
+        # multirow, so it's ambiguous.  this is a behavioral change
+        # in 1.4
+        eq_(result.inserted_primary_key, (None,))
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+            [(1, "ab", "bz", False)],
+        )
 
-    def test_on_duplicate_key_update_singlerow(self):
+    def test_on_duplicate_key_update_singlerow(self, connection):
         foos = self.tables.foos
-        with testing.db.connect() as conn:
-            conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
-            stmt = insert(foos).values(dict(id=2, bar="b"))
-            stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
-
-            result = conn.execute(stmt)
-
-            # only one row in the INSERT so we do inserted_primary_key
-            eq_(result.inserted_primary_key, (2,))
-            eq_(
-                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
-                [(1, "b", "bz", False)],
-            )
+        conn = connection
+        conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+        stmt = insert(foos).values(dict(id=2, bar="b"))
+        stmt = stmt.on_duplicate_key_update(bar=stmt.inserted.bar)
+
+        result = conn.execute(stmt)
+
+        # only one row in the INSERT so we do inserted_primary_key
+        eq_(result.inserted_primary_key, (2,))
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+            [(1, "b", "bz", False)],
+        )
 
-    def test_on_duplicate_key_update_null_multirow(self):
+    def test_on_duplicate_key_update_null_multirow(self, connection):
         foos = self.tables.foos
-        with testing.db.connect() as conn:
-            conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
-            stmt = insert(foos).values(
-                [dict(id=1, bar="ab"), dict(id=2, bar="b")]
-            )
-            stmt = stmt.on_duplicate_key_update(updated_once=None)
-            result = conn.execute(stmt)
-
-            # ambiguous
-            eq_(result.inserted_primary_key, (None,))
-            eq_(
-                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
-                [(1, "b", "bz", None)],
-            )
+        conn = connection
+        conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+        stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")])
+        stmt = stmt.on_duplicate_key_update(updated_once=None)
+        result = conn.execute(stmt)
+
+        # ambiguous
+        eq_(result.inserted_primary_key, (None,))
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+            [(1, "b", "bz", None)],
+        )
 
-    def test_on_duplicate_key_update_expression_multirow(self):
+    def test_on_duplicate_key_update_expression_multirow(self, connection):
         foos = self.tables.foos
-        with testing.db.connect() as conn:
-            conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
-            stmt = insert(foos).values(
-                [dict(id=1, bar="ab"), dict(id=2, bar="b")]
-            )
-            stmt = stmt.on_duplicate_key_update(
-                bar=func.concat(stmt.inserted.bar, "_foo")
-            )
-            result = conn.execute(stmt)
-            eq_(result.inserted_primary_key, (None,))
-            eq_(
-                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
-                [(1, "ab_foo", "bz", False)],
-            )
+        conn = connection
+        conn.execute(insert(foos, dict(id=1, bar="b", baz="bz")))
+        stmt = insert(foos).values([dict(id=1, bar="ab"), dict(id=2, bar="b")])
+        stmt = stmt.on_duplicate_key_update(
+            bar=func.concat(stmt.inserted.bar, "_foo")
+        )
+        result = conn.execute(stmt)
+        eq_(result.inserted_primary_key, (None,))
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+            [(1, "ab_foo", "bz", False)],
+        )
 
-    def test_on_duplicate_key_update_preserve_order(self):
+    def test_on_duplicate_key_update_preserve_order(self, connection):
         foos = self.tables.foos
-        with testing.db.connect() as conn:
-            conn.execute(
-                insert(
-                    foos,
-                    [
-                        dict(id=1, bar="b", baz="bz"),
-                        dict(id=2, bar="b", baz="bz2"),
-                    ],
-                )
-            )
-
-            stmt = insert(foos)
-            update_condition = foos.c.updated_once == False
-
-            # The following statements show importance of the columns update
-            # ordering as old values being referenced in UPDATE clause are
-            # getting replaced one by one from left to right with their new
-            # values.
-            stmt1 = stmt.on_duplicate_key_update(
+        conn = connection
+        conn.execute(
+            insert(
+                foos,
                 [
-                    (
-                        "bar",
-                        func.if_(
-                            update_condition,
-                            func.values(foos.c.bar),
-                            foos.c.bar,
-                        ),
-                    ),
-                    (
-                        "updated_once",
-                        func.if_(update_condition, True, foos.c.updated_once),
-                    ),
-                ]
+                    dict(id=1, bar="b", baz="bz"),
+                    dict(id=2, bar="b", baz="bz2"),
+                ],
             )
-            stmt2 = stmt.on_duplicate_key_update(
-                [
-                    (
-                        "updated_once",
-                        func.if_(update_condition, True, foos.c.updated_once),
+        )
+
+        stmt = insert(foos)
+        update_condition = foos.c.updated_once == False
+
+        # The following statements show the importance of the column update
+        # ordering: old values referenced in the UPDATE clause are
+        # replaced one by one, from left to right, with their new
+        # values.
+        stmt1 = stmt.on_duplicate_key_update(
+            [
+                (
+                    "bar",
+                    func.if_(
+                        update_condition,
+                        func.values(foos.c.bar),
+                        foos.c.bar,
                     ),
-                    (
-                        "bar",
-                        func.if_(
-                            update_condition,
-                            func.values(foos.c.bar),
-                            foos.c.bar,
-                        ),
+                ),
+                (
+                    "updated_once",
+                    func.if_(update_condition, True, foos.c.updated_once),
+                ),
+            ]
+        )
+        stmt2 = stmt.on_duplicate_key_update(
+            [
+                (
+                    "updated_once",
+                    func.if_(update_condition, True, foos.c.updated_once),
+                ),
+                (
+                    "bar",
+                    func.if_(
+                        update_condition,
+                        func.values(foos.c.bar),
+                        foos.c.bar,
                     ),
-                ]
-            )
-            # First statement should succeed updating column bar
-            conn.execute(stmt1, dict(id=1, bar="ab"))
-            eq_(
-                conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
-                [(1, "ab", "bz", True)],
-            )
-            # Second statement will do noop update of column bar
-            conn.execute(stmt2, dict(id=2, bar="ab"))
-            eq_(
-                conn.execute(foos.select().where(foos.c.id == 2)).fetchall(),
-                [(2, "b", "bz2", True)],
-            )
+                ),
+            ]
+        )
+        # First statement should succeed updating column bar
+        conn.execute(stmt1, dict(id=1, bar="ab"))
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 1)).fetchall(),
+            [(1, "ab", "bz", True)],
+        )
+        # Second statement will do a no-op update of column bar
+        conn.execute(stmt2, dict(id=2, bar="ab"))
+        eq_(
+            conn.execute(foos.select().where(foos.c.id == 2)).fetchall(),
+            [(2, "b", "bz2", True)],
+        )
 
-    def test_last_inserted_id(self):
+    def test_last_inserted_id(self, connection):
         foos = self.tables.foos
-        with testing.db.connect() as conn:
-            stmt = insert(foos).values({"bar": "b", "baz": "bz"})
-            result = conn.execute(
-                stmt.on_duplicate_key_update(
-                    bar=stmt.inserted.bar, baz="newbz"
-                )
-            )
-            eq_(result.inserted_primary_key, (1,))
+        conn = connection
+        stmt = insert(foos).values({"bar": "b", "baz": "bz"})
+        result = conn.execute(
+            stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz")
+        )
+        eq_(result.inserted_primary_key, (1,))
 
-            stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"})
-            result = conn.execute(
-                stmt.on_duplicate_key_update(
-                    bar=stmt.inserted.bar, baz="newbz"
-                )
-            )
-            eq_(result.inserted_primary_key, (1,))
+        stmt = insert(foos).values({"id": 1, "bar": "b", "baz": "bz"})
+        result = conn.execute(
+            stmt.on_duplicate_key_update(bar=stmt.inserted.bar, baz="newbz")
+        )
+        eq_(result.inserted_primary_key, (1,))
index f9d9caf166b9931f0ede1398b49ef2912ab6f026..f56cd98aa3ab64e9a91494bf06a78d97516eff10 100644 (file)
@@ -9,7 +9,6 @@ from sqlalchemy import Column
 from sqlalchemy import false
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
-from sqlalchemy import MetaData
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import String
@@ -44,16 +43,13 @@ class IdiosyncrasyTest(fixtures.TestBase):
         )
 
 
-class MatchTest(fixtures.TestBase):
+class MatchTest(fixtures.TablesTest):
     __only_on__ = "mysql", "mariadb"
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global metadata, cattable, matchtable
-        metadata = MetaData(testing.db)
-
-        cattable = Table(
+    def define_tables(cls, metadata):
+        Table(
             "cattable",
             metadata,
             Column("id", Integer, primary_key=True),
@@ -61,7 +57,7 @@ class MatchTest(fixtures.TestBase):
             mysql_engine="MyISAM",
             mariadb_engine="MyISAM",
         )
-        matchtable = Table(
+        Table(
             "matchtable",
             metadata,
             Column("id", Integer, primary_key=True),
@@ -70,15 +66,20 @@ class MatchTest(fixtures.TestBase):
             mysql_engine="MyISAM",
             mariadb_engine="MyISAM",
         )
-        metadata.create_all()
 
-        cattable.insert().execute(
+    @classmethod
+    def insert_data(cls, connection):
+        cattable, matchtable = cls.tables("cattable", "matchtable")
+
+        connection.execute(
+            cattable.insert(),
             [
                 {"id": 1, "description": "Python"},
                 {"id": 2, "description": "Ruby"},
-            ]
+            ],
         )
-        matchtable.insert().execute(
+        connection.execute(
+            matchtable.insert(),
             [
                 {
                     "id": 1,
@@ -97,43 +98,36 @@ class MatchTest(fixtures.TestBase):
                     "category_id": 1,
                 },
                 {"id": 5, "title": "Python in a Nutshell", "category_id": 1},
-            ]
+            ],
         )
 
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
-
-    def test_simple_match(self):
-        results = (
+    def test_simple_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
             matchtable.select()
             .where(matchtable.c.title.match("python"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([2, 5], [r.id for r in results])
 
-    def test_not_match(self):
-        results = (
+    def test_not_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
             matchtable.select()
             .where(~matchtable.c.title.match("python"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
         )
         eq_([1, 3, 4], [r.id for r in results])
 
-    def test_simple_match_with_apostrophe(self):
-        results = (
-            matchtable.select()
-            .where(matchtable.c.title.match("Matz's"))
-            .execute()
-            .fetchall()
-        )
+    def test_simple_match_with_apostrophe(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
+            matchtable.select().where(matchtable.c.title.match("Matz's"))
+        ).fetchall()
         eq_([3], [r.id for r in results])
 
     def test_return_value(self, connection):
+        matchtable = self.tables.matchtable
         # test [ticket:3263]
         result = connection.execute(
             select(
@@ -155,8 +149,9 @@ class MatchTest(fixtures.TestBase):
             ],
         )
 
-    def test_or_match(self):
-        results1 = (
+    def test_or_match(self, connection):
+        matchtable = self.tables.matchtable
+        results1 = connection.execute(
             matchtable.select()
             .where(
                 or_(
@@ -165,42 +160,37 @@ class MatchTest(fixtures.TestBase):
                 )
             )
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([1, 3, 5], [r.id for r in results1])
-        results2 = (
+        results2 = connection.execute(
             matchtable.select()
             .where(matchtable.c.title.match("nutshell ruby"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([1, 3, 5], [r.id for r in results2])
 
-    def test_and_match(self):
-        results1 = (
-            matchtable.select()
-            .where(
+    def test_and_match(self, connection):
+        matchtable = self.tables.matchtable
+        results1 = connection.execute(
+            matchtable.select().where(
                 and_(
                     matchtable.c.title.match("python"),
                     matchtable.c.title.match("nutshell"),
                 )
             )
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([5], [r.id for r in results1])
-        results2 = (
-            matchtable.select()
-            .where(matchtable.c.title.match("+python +nutshell"))
-            .execute()
-            .fetchall()
-        )
+        results2 = connection.execute(
+            matchtable.select().where(
+                matchtable.c.title.match("+python +nutshell")
+            )
+        ).fetchall()
         eq_([5], [r.id for r in results2])
 
-    def test_match_across_joins(self):
-        results = (
+    def test_match_across_joins(self, connection):
+        matchtable = self.tables.matchtable
+        cattable = self.tables.cattable
+        results = connection.execute(
             matchtable.select()
             .where(
                 and_(
@@ -212,9 +202,7 @@ class MatchTest(fixtures.TestBase):
                 )
             )
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([1, 3, 5], [r.id for r in results])
 
 
index 3871dbecca7dfa241dc875863c25f80b33255655..55d88957a31855805667a8df1be69af39d201d65 100644 (file)
@@ -324,7 +324,8 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
             str(reflected.c.c6.server_default.arg).upper(),
         )
 
-    def test_reflection_with_table_options(self):
+    @testing.provide_metadata
+    def test_reflection_with_table_options(self, connection):
         comment = r"""Comment types type speedily ' " \ '' Fun!"""
         if testing.against("mariadb"):
             kwargs = dict(
@@ -347,18 +348,15 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
 
         def_table = Table(
             "mysql_def",
-            MetaData(),
+            self.metadata,
             Column("c1", Integer()),
             comment=comment,
             **kwargs
         )
 
-        with testing.db.connect() as conn:
-            def_table.create(conn)
-            try:
-                reflected = Table("mysql_def", MetaData(), autoload_with=conn)
-            finally:
-                def_table.drop(conn)
+        conn = connection
+        def_table.create(conn)
+        reflected = Table("mysql_def", MetaData(), autoload_with=conn)
 
         if testing.against("mariadb"):
             assert def_table.kwargs["mariadb_engine"] == "MEMORY"
@@ -554,31 +552,31 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
                     assert 1 not in list(conn.execute(tbl.select()).first())
 
     @testing.provide_metadata
-    def test_view_reflection(self):
+    def test_view_reflection(self, connection):
         Table(
             "x", self.metadata, Column("a", Integer), Column("b", String(50))
         )
-        self.metadata.create_all()
+        self.metadata.create_all(connection)
 
-        with testing.db.connect() as conn:
-            conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x")
-            conn.exec_driver_sql(
-                "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x"
-            )
-            conn.exec_driver_sql(
-                "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x"
-            )
-            conn.exec_driver_sql(
-                "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x"
-            )
+        conn = connection
+        conn.exec_driver_sql("CREATE VIEW v1 AS SELECT * FROM x")
+        conn.exec_driver_sql(
+            "CREATE ALGORITHM=MERGE VIEW v2 AS SELECT * FROM x"
+        )
+        conn.exec_driver_sql(
+            "CREATE ALGORITHM=UNDEFINED VIEW v3 AS SELECT * FROM x"
+        )
+        conn.exec_driver_sql(
+            "CREATE DEFINER=CURRENT_USER VIEW v4 AS SELECT * FROM x"
+        )
 
         @event.listens_for(self.metadata, "before_drop")
         def cleanup(*arg, **kw):
-            with testing.db.connect() as conn:
+            with testing.db.begin() as conn:
                 for v in ["v1", "v2", "v3", "v4"]:
                     conn.exec_driver_sql("DROP VIEW %s" % v)
 
-        insp = inspect(testing.db)
+        insp = inspect(connection)
         for v in ["v1", "v2", "v3", "v4"]:
             eq_(
                 [
@@ -589,38 +587,36 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
             )
 
     @testing.provide_metadata
-    def test_skip_not_describable(self):
+    def test_skip_not_describable(self, connection):
         @event.listens_for(self.metadata, "before_drop")
         def cleanup(*arg, **kw):
-            with testing.db.connect() as conn:
+            with testing.db.begin() as conn:
                 conn.exec_driver_sql("DROP TABLE IF EXISTS test_t1")
                 conn.exec_driver_sql("DROP TABLE IF EXISTS test_t2")
                 conn.exec_driver_sql("DROP VIEW IF EXISTS test_v")
 
-        with testing.db.connect() as conn:
-            conn.exec_driver_sql("CREATE TABLE test_t1 (id INTEGER)")
-            conn.exec_driver_sql("CREATE TABLE test_t2 (id INTEGER)")
-            conn.exec_driver_sql(
-                "CREATE VIEW test_v AS SELECT id FROM test_t1"
-            )
-            conn.exec_driver_sql("DROP TABLE test_t1")
-
-            m = MetaData()
-            with expect_warnings(
-                "Skipping .* Table or view named .?test_v.? could not be "
-                "reflected: .* references invalid table"
-            ):
-                m.reflect(views=True, bind=conn)
-            eq_(m.tables["test_t2"].name, "test_t2")
-
-            assert_raises_message(
-                exc.UnreflectableTableError,
-                "references invalid table",
-                Table,
-                "test_v",
-                MetaData(),
-                autoload_with=conn,
-            )
+        conn = connection
+        conn.exec_driver_sql("CREATE TABLE test_t1 (id INTEGER)")
+        conn.exec_driver_sql("CREATE TABLE test_t2 (id INTEGER)")
+        conn.exec_driver_sql("CREATE VIEW test_v AS SELECT id FROM test_t1")
+        conn.exec_driver_sql("DROP TABLE test_t1")
+
+        m = MetaData()
+        with expect_warnings(
+            "Skipping .* Table or view named .?test_v.? could not be "
+            "reflected: .* references invalid table"
+        ):
+            m.reflect(views=True, bind=conn)
+        eq_(m.tables["test_t2"].name, "test_t2")
+
+        assert_raises_message(
+            exc.UnreflectableTableError,
+            "references invalid table",
+            Table,
+            "test_v",
+            MetaData(),
+            autoload_with=conn,
+        )
 
     @testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support")
     def test_system_views(self):
@@ -663,7 +659,7 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
         ):
             Table("nn_t%d" % idx, meta)  # to allow DROP
 
-            with testing.db.connect() as c:
+            with testing.db.begin() as c:
                 c.exec_driver_sql(
                     """
                         CREATE TABLE nn_t%d (
index aafad8dc15b48631d0fe88a26540a74b8415fd44..9a2174a24a7a00181a6eb51a5471429b7929d198 100644 (file)
@@ -89,6 +89,8 @@ class DefaultSchemaNameTest(fixtures.TestBase):
         eng = engines.testing_engine()
 
         with eng.connect() as conn:
+
+            trans = conn.begin()
             eq_(
                 testing.db.dialect._get_default_schema_name(conn),
                 default_schema_name,
@@ -104,6 +106,7 @@ class DefaultSchemaNameTest(fixtures.TestBase):
             )
 
             conn.invalidate()
+            trans.rollback()
 
             eq_(
                 testing.db.dialect._get_default_schema_name(conn),
@@ -317,53 +320,51 @@ class ComputedReturningTest(fixtures.TablesTest):
             implicit_returning=False,
         )
 
-    def test_computed_insert(self):
+    def test_computed_insert(self, connection):
         test = self.tables.test
-        with testing.db.connect() as conn:
-            result = conn.execute(
-                test.insert().return_defaults(), {"id": 1, "foo": 5}
-            )
+        conn = connection
+        result = conn.execute(
+            test.insert().return_defaults(), {"id": 1, "foo": 5}
+        )
 
-            eq_(result.returned_defaults, (47,))
+        eq_(result.returned_defaults, (47,))
 
-            eq_(conn.scalar(select(test.c.bar)), 47)
+        eq_(conn.scalar(select(test.c.bar)), 47)
 
-    def test_computed_update_warning(self):
+    def test_computed_update_warning(self, connection):
         test = self.tables.test
-        with testing.db.connect() as conn:
-            conn.execute(test.insert(), {"id": 1, "foo": 5})
+        conn = connection
+        conn.execute(test.insert(), {"id": 1, "foo": 5})
 
-            if testing.db.dialect._supports_update_returning_computed_cols:
+        if testing.db.dialect._supports_update_returning_computed_cols:
+            result = conn.execute(
+                test.update().values(foo=10).return_defaults()
+            )
+            eq_(result.returned_defaults, (52,))
+        else:
+            with testing.expect_warnings(
+                "Computed columns don't work with Oracle UPDATE"
+            ):
                 result = conn.execute(
                     test.update().values(foo=10).return_defaults()
                 )
-                eq_(result.returned_defaults, (52,))
-            else:
-                with testing.expect_warnings(
-                    "Computed columns don't work with Oracle UPDATE"
-                ):
-                    result = conn.execute(
-                        test.update().values(foo=10).return_defaults()
-                    )
 
-                    # returns the *old* value
-                    eq_(result.returned_defaults, (47,))
+                # returns the *old* value
+                eq_(result.returned_defaults, (47,))
 
-            eq_(conn.scalar(select(test.c.bar)), 52)
+        eq_(conn.scalar(select(test.c.bar)), 52)
 
-    def test_computed_update_no_warning(self):
+    def test_computed_update_no_warning(self, connection):
         test = self.tables.test_no_returning
-        with testing.db.connect() as conn:
-            conn.execute(test.insert(), {"id": 1, "foo": 5})
+        conn = connection
+        conn.execute(test.insert(), {"id": 1, "foo": 5})
 
-            result = conn.execute(
-                test.update().values(foo=10).return_defaults()
-            )
+        result = conn.execute(test.update().values(foo=10).return_defaults())
 
-            # no returning
-            eq_(result.returned_defaults, None)
+        # no returning
+        eq_(result.returned_defaults, None)
 
-            eq_(conn.scalar(select(test.c.bar)), 52)
+        eq_(conn.scalar(select(test.c.bar)), 52)
 
 
 class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
@@ -372,7 +373,7 @@ class OutParamTest(fixtures.TestBase, AssertsExecutionResults):
 
     @classmethod
     def setup_class(cls):
-        with testing.db.connect() as c:
+        with testing.db.begin() as c:
             c.exec_driver_sql(
                 """
 create or replace procedure foo(x_in IN number, x_out OUT number,
@@ -404,7 +405,7 @@ end;
 
     @classmethod
     def teardown_class(cls):
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(text("DROP PROCEDURE foo"))
 
 
@@ -674,7 +675,7 @@ class ExecuteTest(fixtures.TestBase):
             seq.drop(connection)
 
     @testing.provide_metadata
-    def test_limit_offset_for_update(self):
+    def test_limit_offset_for_update(self, connection):
         metadata = self.metadata
         # oracle can't actually do the ROWNUM thing with FOR UPDATE
         # very well.
@@ -685,19 +686,24 @@ class ExecuteTest(fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("data", Integer),
         )
-        metadata.create_all()
+        metadata.create_all(connection)
 
-        t.insert().execute(
-            {"id": 1, "data": 1},
-            {"id": 2, "data": 7},
-            {"id": 3, "data": 12},
-            {"id": 4, "data": 15},
-            {"id": 5, "data": 32},
+        connection.execute(
+            t.insert(),
+            [
+                {"id": 1, "data": 1},
+                {"id": 2, "data": 7},
+                {"id": 3, "data": 12},
+                {"id": 4, "data": 15},
+                {"id": 5, "data": 32},
+            ],
         )
 
         # here, we can't use ORDER BY.
         eq_(
-            t.select().with_for_update().limit(2).execute().fetchall(),
+            connection.execute(
+                t.select().with_for_update().limit(2)
+            ).fetchall(),
             [(1, 1), (2, 7)],
         )
 
@@ -706,7 +712,8 @@ class ExecuteTest(fixtures.TestBase):
         assert_raises_message(
             exc.DatabaseError,
             "ORA-02014",
-            t.select().with_for_update().limit(2).offset(3).execute,
+            connection.execute,
+            t.select().with_for_update().limit(2).offset(3),
         )
 
 
index efa21fc1a3d6940e862e603425e0400f967a6512..2e515556f37db4330c561a7698a51e77543548a8 100644 (file)
@@ -34,11 +34,6 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
 
-def exec_sql(engine, sql, *args, **kwargs):
-    with engine.connect() as conn:
-        return conn.exec_driver_sql(sql, *args, **kwargs)
-
-
 class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
     __only_on__ = "oracle"
     __backend__ = True
@@ -49,62 +44,64 @@ class MultiSchemaTest(fixtures.TestBase, AssertsCompiledSQL):
         # don't really know how else to go here unless
         # we connect as the other user.
 
-        for stmt in (
-            """
-create table %(test_schema)s.parent(
-    id integer primary key,
-    data varchar2(50)
-);
-
-COMMENT ON TABLE %(test_schema)s.parent IS 'my table comment';
-
-create table %(test_schema)s.child(
-    id integer primary key,
-    data varchar2(50),
-    parent_id integer references %(test_schema)s.parent(id)
-);
-
-create table local_table(
-    id integer primary key,
-    data varchar2(50)
-);
-
-create synonym %(test_schema)s.ptable for %(test_schema)s.parent;
-create synonym %(test_schema)s.ctable for %(test_schema)s.child;
-
-create synonym %(test_schema)s_pt for %(test_schema)s.parent;
-
-create synonym %(test_schema)s.local_table for local_table;
-
--- can't make a ref from local schema to the
--- remote schema's table without this,
--- *and* cant give yourself a grant !
--- so we give it to public.  ideas welcome.
-grant references on %(test_schema)s.parent to public;
-grant references on %(test_schema)s.child to public;
-"""
-            % {"test_schema": testing.config.test_schema}
-        ).split(";"):
-            if stmt.strip():
-                exec_sql(testing.db, stmt)
+        with testing.db.begin() as conn:
+            for stmt in (
+                """
+    create table %(test_schema)s.parent(
+        id integer primary key,
+        data varchar2(50)
+    );
+
+    COMMENT ON TABLE %(test_schema)s.parent IS 'my table comment';
+
+    create table %(test_schema)s.child(
+        id integer primary key,
+        data varchar2(50),
+        parent_id integer references %(test_schema)s.parent(id)
+    );
+
+    create table local_table(
+        id integer primary key,
+        data varchar2(50)
+    );
+
+    create synonym %(test_schema)s.ptable for %(test_schema)s.parent;
+    create synonym %(test_schema)s.ctable for %(test_schema)s.child;
+
+    create synonym %(test_schema)s_pt for %(test_schema)s.parent;
+
+    create synonym %(test_schema)s.local_table for local_table;
+
+    -- can't make a ref from local schema to the
+    -- remote schema's table without this,
+    -- *and* cant give yourself a grant !
+    -- so we give it to public.  ideas welcome.
+    grant references on %(test_schema)s.parent to public;
+    grant references on %(test_schema)s.child to public;
+    """
+                % {"test_schema": testing.config.test_schema}
+            ).split(";"):
+                if stmt.strip():
+                    conn.exec_driver_sql(stmt)
 
     @classmethod
     def teardown_class(cls):
-        for stmt in (
-            """
-drop table %(test_schema)s.child;
-drop table %(test_schema)s.parent;
-drop table local_table;
-drop synonym %(test_schema)s.ctable;
-drop synonym %(test_schema)s.ptable;
-drop synonym %(test_schema)s_pt;
-drop synonym %(test_schema)s.local_table;
-
-"""
-            % {"test_schema": testing.config.test_schema}
-        ).split(";"):
-            if stmt.strip():
-                exec_sql(testing.db, stmt)
+        with testing.db.begin() as conn:
+            for stmt in (
+                """
+    drop table %(test_schema)s.child;
+    drop table %(test_schema)s.parent;
+    drop table local_table;
+    drop synonym %(test_schema)s.ctable;
+    drop synonym %(test_schema)s.ptable;
+    drop synonym %(test_schema)s_pt;
+    drop synonym %(test_schema)s.local_table;
+
+    """
+                % {"test_schema": testing.config.test_schema}
+            ).split(";"):
+                if stmt.strip():
+                    conn.exec_driver_sql(stmt)
 
     @testing.provide_metadata
     def test_create_same_names_explicit_schema(self):
@@ -162,7 +159,7 @@ drop synonym %(test_schema)s.local_table;
         )
 
     @testing.provide_metadata
-    def test_create_same_names_implicit_schema(self):
+    def test_create_same_names_implicit_schema(self, connection):
         meta = self.metadata
         parent = Table(
             "parent", meta, Column("pid", Integer, primary_key=True)
@@ -173,10 +170,11 @@ drop synonym %(test_schema)s.local_table;
             Column("cid", Integer, primary_key=True),
             Column("pid", Integer, ForeignKey("parent.pid")),
         )
-        meta.create_all()
-        parent.insert().execute({"pid": 1})
-        child.insert().execute({"cid": 1, "pid": 1})
-        eq_(child.select().execute().fetchall(), [(1, 1)])
+        meta.create_all(connection)
+
+        connection.execute(parent.insert(), {"pid": 1})
+        connection.execute(child.insert(), {"cid": 1, "pid": 1})
+        eq_(connection.execute(child.select()).fetchall(), [(1, 1)])
 
     def test_reflect_alt_owner_explicit(self):
         meta = MetaData()
@@ -238,9 +236,8 @@ drop synonym %(test_schema)s.local_table;
             {"text": "my local comment"},
         )
 
-    def test_reflect_local_to_remote(self):
-        exec_sql(
-            testing.db,
+    def test_reflect_local_to_remote(self, connection):
+        connection.exec_driver_sql(
             "CREATE TABLE localtable (id INTEGER "
             "PRIMARY KEY, parent_id INTEGER REFERENCES "
             "%(test_schema)s.parent(id))"
@@ -258,7 +255,7 @@ drop synonym %(test_schema)s.local_table;
                 % {"test_schema": testing.config.test_schema},
             )
         finally:
-            exec_sql(testing.db, "DROP TABLE localtable")
+            connection.exec_driver_sql("DROP TABLE localtable")
 
     def test_reflect_alt_owner_implicit(self):
         meta = MetaData()
@@ -286,9 +283,8 @@ drop synonym %(test_schema)s.local_table;
                 select(parent, child).select_from(parent.join(child))
             ).fetchall()
 
-    def test_reflect_alt_owner_synonyms(self):
-        exec_sql(
-            testing.db,
+    def test_reflect_alt_owner_synonyms(self, connection):
+        connection.exec_driver_sql(
             "CREATE TABLE localtable (id INTEGER "
             "PRIMARY KEY, parent_id INTEGER REFERENCES "
             "%s.ptable(id))" % testing.config.test_schema,
@@ -298,7 +294,7 @@ drop synonym %(test_schema)s.local_table;
             lcl = Table(
                 "localtable",
                 meta,
-                autoload_with=testing.db,
+                autoload_with=connection,
                 oracle_resolve_synonyms=True,
             )
             parent = meta.tables["%s.ptable" % testing.config.test_schema]
@@ -309,12 +305,11 @@ drop synonym %(test_schema)s.local_table;
                 "localtable.parent_id"
                 % {"test_schema": testing.config.test_schema},
             )
-            with testing.db.connect() as conn:
-                conn.execute(
-                    select(parent, lcl).select_from(parent.join(lcl))
-                ).fetchall()
+            connection.execute(
+                select(parent, lcl).select_from(parent.join(lcl))
+            ).fetchall()
         finally:
-            exec_sql(testing.db, "DROP TABLE localtable")
+            connection.exec_driver_sql("DROP TABLE localtable")
 
     def test_reflect_remote_synonyms(self):
         meta = MetaData()
@@ -389,19 +384,20 @@ class SystemTableTablenamesTest(fixtures.TestBase):
     __backend__ = True
 
     def setup(self):
-        exec_sql(testing.db, "create table my_table (id integer)")
-        exec_sql(
-            testing.db,
-            "create global temporary table my_temp_table (id integer)",
-        )
-        exec_sql(
-            testing.db, "create table foo_table (id integer) tablespace SYSTEM"
-        )
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("create table my_table (id integer)")
+            conn.exec_driver_sql(
+                "create global temporary table my_temp_table (id integer)",
+            )
+            conn.exec_driver_sql(
+                "create table foo_table (id integer) tablespace SYSTEM"
+            )
 
     def teardown(self):
-        exec_sql(testing.db, "drop table my_temp_table")
-        exec_sql(testing.db, "drop table my_table")
-        exec_sql(testing.db, "drop table foo_table")
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("drop table my_temp_table")
+            conn.exec_driver_sql("drop table my_table")
+            conn.exec_driver_sql("drop table foo_table")
 
     def test_table_names_no_system(self):
         insp = inspect(testing.db)
@@ -430,24 +426,25 @@ class DontReflectIOTTest(fixtures.TestBase):
     __backend__ = True
 
     def setup(self):
-        exec_sql(
-            testing.db,
-            """
-        CREATE TABLE admin_docindex(
-                token char(20),
-                doc_id NUMBER,
-                token_frequency NUMBER,
-                token_offsets VARCHAR2(2000),
-                CONSTRAINT pk_admin_docindex PRIMARY KEY (token, doc_id))
-            ORGANIZATION INDEX
-            TABLESPACE users
-            PCTTHRESHOLD 20
-            OVERFLOW TABLESPACE users
-        """,
-        )
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql(
+                """
+            CREATE TABLE admin_docindex(
+                    token char(20),
+                    doc_id NUMBER,
+                    token_frequency NUMBER,
+                    token_offsets VARCHAR2(2000),
+                    CONSTRAINT pk_admin_docindex PRIMARY KEY (token, doc_id))
+                ORGANIZATION INDEX
+                TABLESPACE users
+                PCTTHRESHOLD 20
+                OVERFLOW TABLESPACE users
+            """,
+            )
 
     def teardown(self):
-        exec_sql(testing.db, "drop table admin_docindex")
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("drop table admin_docindex")
 
     def test_reflect_all(self):
         m = MetaData(testing.db)
@@ -456,30 +453,24 @@ class DontReflectIOTTest(fixtures.TestBase):
 
 
 def all_tables_compression_missing():
-    try:
-        exec_sql(testing.db, "SELECT compression FROM all_tables")
+    with testing.db.connect() as conn:
         if (
             "Enterprise Edition"
-            not in exec_sql(testing.db, "select * from v$version").scalar()
+            not in conn.exec_driver_sql("select * from v$version").scalar()
             # this works in Oracle Database 18c Express Edition Release
         ) and testing.db.dialect.server_version_info < (18,):
             return True
         return False
-    except Exception:
-        return True
 
 
 def all_tables_compress_for_missing():
-    try:
-        exec_sql(testing.db, "SELECT compress_for FROM all_tables")
+    with testing.db.connect() as conn:
         if (
             "Enterprise Edition"
-            not in exec_sql(testing.db, "select * from v$version").scalar()
+            not in conn.exec_driver_sql("select * from v$version").scalar()
         ):
             return True
         return False
-    except Exception:
-        return True
 
 
 class TableReflectionTest(fixtures.TestBase):
@@ -748,7 +739,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
         # note that the synonym here is still not totally functional
         # when accessing via a different username as we do with the
         # multiprocess test suite, so testing here is minimal
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql(
                 "create table test_table "
                 "(id integer primary key, data varchar2(50))"
@@ -760,7 +751,7 @@ class DBLinkReflectionTest(fixtures.TestBase):
 
     @classmethod
     def teardown_class(cls):
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql("drop synonym test_table_syn")
             conn.exec_driver_sql("drop table test_table")
 
index 8fbf374ee59d22c1a8f4528e47171ed364e85510..db3825d13750e4603d66475e19bc085c138822b0 100644 (file)
@@ -228,16 +228,16 @@ class TypesTest(fixtures.TestBase):
 
     @testing.requires.returning
     @testing.provide_metadata
-    def test_int_not_float(self):
+    def test_int_not_float(self, connection):
         m = self.metadata
         t1 = Table("t1", m, Column("foo", Integer))
-        t1.create()
-        r = t1.insert().values(foo=5).returning(t1.c.foo).execute()
+        t1.create(connection)
+        r = connection.execute(t1.insert().values(foo=5).returning(t1.c.foo))
         x = r.scalar()
         assert x == 5
         assert isinstance(x, int)
 
-        x = t1.select().scalar()
+        x = connection.scalar(t1.select())
         assert x == 5
         assert isinstance(x, int)
 
@@ -281,7 +281,7 @@ class TypesTest(fixtures.TestBase):
             eq_(conn.execute(s3).fetchall(), [(5, rowid)])
 
     @testing.provide_metadata
-    def test_interval(self):
+    def test_interval(self, connection):
         metadata = self.metadata
         interval_table = Table(
             "intervaltable",
@@ -291,11 +291,12 @@ class TypesTest(fixtures.TestBase):
             ),
             Column("day_interval", oracle.INTERVAL(day_precision=3)),
         )
-        metadata.create_all()
-        interval_table.insert().execute(
-            day_interval=datetime.timedelta(days=35, seconds=5743)
+        metadata.create_all(connection)
+        connection.execute(
+            interval_table.insert(),
+            dict(day_interval=datetime.timedelta(days=35, seconds=5743)),
         )
-        row = interval_table.select().execute().first()
+        row = connection.execute(interval_table.select()).first()
         eq_(row["day_interval"], datetime.timedelta(days=35, seconds=5743))
 
     @testing.provide_metadata
@@ -364,16 +365,19 @@ class TypesTest(fixtures.TestBase):
             Column("intcol", Integer),
             Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)),
         )
-        t1.create()
-        t1.insert().execute(
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
             [
                 dict(intcol=1, numericcol=float("inf")),
                 dict(intcol=2, numericcol=float("-inf")),
-            ]
+            ],
         )
 
         eq_(
-            select(t1.c.numericcol).order_by(t1.c.intcol).execute().fetchall(),
+            connection.execute(
+                select(t1.c.numericcol).order_by(t1.c.intcol)
+            ).fetchall(),
             [(float("inf"),), (float("-inf"),)],
         )
 
@@ -393,16 +397,19 @@ class TypesTest(fixtures.TestBase):
             Column("intcol", Integer),
             Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=True)),
         )
-        t1.create()
-        t1.insert().execute(
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
             [
                 dict(intcol=1, numericcol=decimal.Decimal("Infinity")),
                 dict(intcol=2, numericcol=decimal.Decimal("-Infinity")),
-            ]
+            ],
         )
 
         eq_(
-            select(t1.c.numericcol).order_by(t1.c.intcol).execute().fetchall(),
+            connection.execute(
+                select(t1.c.numericcol).order_by(t1.c.intcol)
+            ).fetchall(),
             [(decimal.Decimal("Infinity"),), (decimal.Decimal("-Infinity"),)],
         )
 
@@ -422,20 +429,21 @@ class TypesTest(fixtures.TestBase):
             Column("intcol", Integer),
             Column("numericcol", oracle.BINARY_DOUBLE(asdecimal=False)),
         )
-        t1.create()
-        t1.insert().execute(
+        t1.create(connection)
+        connection.execute(
+            t1.insert(),
             [
                 dict(intcol=1, numericcol=float("nan")),
                 dict(intcol=2, numericcol=float("-nan")),
-            ]
+            ],
         )
 
         eq_(
             [
                 tuple(str(col) for col in row)
-                for row in select(t1.c.numericcol)
-                .order_by(t1.c.intcol)
-                .execute()
+                for row in connection.execute(
+                    select(t1.c.numericcol).order_by(t1.c.intcol)
+                )
             ],
             [("nan",), ("nan",)],
         )
@@ -786,7 +794,7 @@ class TypesTest(fixtures.TestBase):
         eq_(connection.execute(raw_table.select()).first(), (1, b("ABCDEF")))
 
     @testing.provide_metadata
-    def test_reflect_nvarchar(self):
+    def test_reflect_nvarchar(self, connection):
         metadata = self.metadata
         Table(
             "tnv",
@@ -794,31 +802,30 @@ class TypesTest(fixtures.TestBase):
             Column("nv_data", sqltypes.NVARCHAR(255)),
             Column("c_data", sqltypes.NCHAR(20)),
         )
-        metadata.create_all()
+        metadata.create_all(connection)
         m2 = MetaData()
-        t2 = Table("tnv", m2, autoload_with=testing.db)
+        t2 = Table("tnv", m2, autoload_with=connection)
         assert isinstance(t2.c.nv_data.type, sqltypes.NVARCHAR)
         assert isinstance(t2.c.c_data.type, sqltypes.NCHAR)
 
         if testing.against("oracle+cx_oracle"):
             assert isinstance(
-                t2.c.nv_data.type.dialect_impl(testing.db.dialect),
+                t2.c.nv_data.type.dialect_impl(connection.dialect),
                 cx_oracle._OracleUnicodeStringNCHAR,
             )
 
             assert isinstance(
-                t2.c.c_data.type.dialect_impl(testing.db.dialect),
+                t2.c.c_data.type.dialect_impl(connection.dialect),
                 cx_oracle._OracleNChar,
             )
 
         data = u("m’a réveillé.")
-        with testing.db.connect() as conn:
-            conn.execute(t2.insert(), dict(nv_data=data, c_data=data))
-            nv_data, c_data = conn.execute(t2.select()).first()
-            eq_(nv_data, data)
-            eq_(c_data, data + (" " * 7))  # char is space padded
-            assert isinstance(nv_data, util.text_type)
-            assert isinstance(c_data, util.text_type)
+        connection.execute(t2.insert(), dict(nv_data=data, c_data=data))
+        nv_data, c_data = connection.execute(t2.select()).first()
+        eq_(nv_data, data)
+        eq_(c_data, data + (" " * 7))  # char is space padded
+        assert isinstance(nv_data, util.text_type)
+        assert isinstance(c_data, util.text_type)
 
     @testing.provide_metadata
     def test_reflect_unicode_no_nvarchar(self):
@@ -1183,7 +1190,7 @@ class SetInputSizesTest(fixtures.TestBase):
         else:
             engine = testing.db
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             connection_fairy = conn.connection
             for tab in [t1, t2, t3]:
                 with mock.patch.object(
index 5cea604d686898eb9accd913fd331b9055c6b0bc..3bd8e9da0b214d9db1269f85862b2f93b5e19c0a 100644 (file)
@@ -36,6 +36,7 @@ from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_VALUES
 from sqlalchemy.engine import cursor as _cursor
 from sqlalchemy.engine import engine_from_config
 from sqlalchemy.engine import url
+from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -51,7 +52,7 @@ from sqlalchemy.testing.assertions import eq_regex
 from sqlalchemy.testing.assertions import ne_
 from sqlalchemy.util import u
 from sqlalchemy.util import ue
-from ...engine import test_execute
+from ...engine import test_deprecations
 
 if True:
     from sqlalchemy.dialects.postgresql.psycopg2 import (
@@ -195,6 +196,20 @@ class ExecuteManyMode(object):
 
     options = None
 
+    @config.fixture()
+    def connection(self):
+        eng = engines.testing_engine(options=self.options)
+
+        conn = eng.connect()
+        trans = conn.begin()
+        try:
+            yield conn
+        finally:
+            if trans.is_active:
+                trans.rollback()
+            conn.close()
+            eng.dispose()
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -213,20 +228,12 @@ class ExecuteManyMode(object):
             Column(ue("\u6e2c\u8a66"), Integer),
         )
 
-    def setup(self):
-        super(ExecuteManyMode, self).setup()
-        self.engine = engines.testing_engine(options=self.options)
-
-    def teardown(self):
-        self.engine.dispose()
-        super(ExecuteManyMode, self).teardown()
-
-    def test_insert(self):
+    def test_insert(self, connection):
         from psycopg2 import extras
 
-        values_page_size = self.engine.dialect.executemany_values_page_size
-        batch_page_size = self.engine.dialect.executemany_batch_page_size
-        if self.engine.dialect.executemany_mode & EXECUTEMANY_VALUES:
+        values_page_size = connection.dialect.executemany_values_page_size
+        batch_page_size = connection.dialect.executemany_batch_page_size
+        if connection.dialect.executemany_mode & EXECUTEMANY_VALUES:
             meth = extras.execute_values
             stmt = "INSERT INTO data (x, y) VALUES %s"
             expected_kwargs = {
@@ -234,7 +241,7 @@ class ExecuteManyMode(object):
                 "page_size": values_page_size,
                 "fetch": False,
             }
-        elif self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH:
+        elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
             meth = extras.execute_batch
             stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)"
             expected_kwargs = {"page_size": batch_page_size}
@@ -244,24 +251,23 @@ class ExecuteManyMode(object):
         with mock.patch.object(
             extras, meth.__name__, side_effect=meth
         ) as mock_exec:
-            with self.engine.connect() as conn:
-                conn.execute(
-                    self.tables.data.insert(),
-                    [
-                        {"x": "x1", "y": "y1"},
-                        {"x": "x2", "y": "y2"},
-                        {"x": "x3", "y": "y3"},
-                    ],
-                )
+            connection.execute(
+                self.tables.data.insert(),
+                [
+                    {"x": "x1", "y": "y1"},
+                    {"x": "x2", "y": "y2"},
+                    {"x": "x3", "y": "y3"},
+                ],
+            )
 
-                eq_(
-                    conn.execute(select(self.tables.data)).fetchall(),
-                    [
-                        (1, "x1", "y1", 5),
-                        (2, "x2", "y2", 5),
-                        (3, "x3", "y3", 5),
-                    ],
-                )
+            eq_(
+                connection.execute(select(self.tables.data)).fetchall(),
+                [
+                    (1, "x1", "y1", 5),
+                    (2, "x2", "y2", 5),
+                    (3, "x3", "y3", 5),
+                ],
+            )
         eq_(
             mock_exec.mock_calls,
             [
@@ -278,14 +284,13 @@ class ExecuteManyMode(object):
             ],
         )
 
-    def test_insert_no_page_size(self):
+    def test_insert_no_page_size(self, connection):
         from psycopg2 import extras
 
-        values_page_size = self.engine.dialect.executemany_values_page_size
-        batch_page_size = self.engine.dialect.executemany_batch_page_size
+        values_page_size = connection.dialect.executemany_values_page_size
+        batch_page_size = connection.dialect.executemany_batch_page_size
 
-        eng = self.engine
-        if eng.dialect.executemany_mode & EXECUTEMANY_VALUES:
+        if connection.dialect.executemany_mode & EXECUTEMANY_VALUES:
             meth = extras.execute_values
             stmt = "INSERT INTO data (x, y) VALUES %s"
             expected_kwargs = {
@@ -293,7 +298,7 @@ class ExecuteManyMode(object):
                 "page_size": values_page_size,
                 "fetch": False,
             }
-        elif eng.dialect.executemany_mode & EXECUTEMANY_BATCH:
+        elif connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
             meth = extras.execute_batch
             stmt = "INSERT INTO data (x, y) VALUES (%(x)s, %(y)s)"
             expected_kwargs = {"page_size": batch_page_size}
@@ -303,15 +308,14 @@ class ExecuteManyMode(object):
         with mock.patch.object(
             extras, meth.__name__, side_effect=meth
         ) as mock_exec:
-            with eng.connect() as conn:
-                conn.execute(
-                    self.tables.data.insert(),
-                    [
-                        {"x": "x1", "y": "y1"},
-                        {"x": "x2", "y": "y2"},
-                        {"x": "x3", "y": "y3"},
-                    ],
-                )
+            connection.execute(
+                self.tables.data.insert(),
+                [
+                    {"x": "x1", "y": "y1"},
+                    {"x": "x2", "y": "y2"},
+                    {"x": "x3", "y": "y3"},
+                ],
+            )
 
         eq_(
             mock_exec.mock_calls,
@@ -356,7 +360,7 @@ class ExecuteManyMode(object):
         with mock.patch.object(
             extras, meth.__name__, side_effect=meth
         ) as mock_exec:
-            with eng.connect() as conn:
+            with eng.begin() as conn:
                 conn.execute(
                     self.tables.data.insert(),
                     [
@@ -398,11 +402,10 @@ class ExecuteManyMode(object):
 
         eq_(connection.execute(table.select()).all(), [(1, 1), (2, 2), (3, 3)])
 
-    def test_update_fallback(self):
+    def test_update_fallback(self, connection):
         from psycopg2 import extras
 
-        batch_page_size = self.engine.dialect.executemany_batch_page_size
-        eng = self.engine
+        batch_page_size = connection.dialect.executemany_batch_page_size
         meth = extras.execute_batch
         stmt = "UPDATE data SET y=%(yval)s WHERE data.x = %(xval)s"
         expected_kwargs = {"page_size": batch_page_size}
@@ -410,18 +413,17 @@ class ExecuteManyMode(object):
         with mock.patch.object(
             extras, meth.__name__, side_effect=meth
         ) as mock_exec:
-            with eng.connect() as conn:
-                conn.execute(
-                    self.tables.data.update()
-                    .where(self.tables.data.c.x == bindparam("xval"))
-                    .values(y=bindparam("yval")),
-                    [
-                        {"xval": "x1", "yval": "y5"},
-                        {"xval": "x3", "yval": "y6"},
-                    ],
-                )
+            connection.execute(
+                self.tables.data.update()
+                .where(self.tables.data.c.x == bindparam("xval"))
+                .values(y=bindparam("yval")),
+                [
+                    {"xval": "x1", "yval": "y5"},
+                    {"xval": "x3", "yval": "y6"},
+                ],
+            )
 
-        if eng.dialect.executemany_mode & EXECUTEMANY_BATCH:
+        if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
             eq_(
                 mock_exec.mock_calls,
                 [
@@ -439,36 +441,34 @@ class ExecuteManyMode(object):
         else:
             eq_(mock_exec.mock_calls, [])
 
-    def test_not_sane_rowcount(self):
-        self.engine.connect().close()
-        if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH:
-            assert not self.engine.dialect.supports_sane_multi_rowcount
+    def test_not_sane_rowcount(self, connection):
+        if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
+            assert not connection.dialect.supports_sane_multi_rowcount
         else:
-            assert self.engine.dialect.supports_sane_multi_rowcount
+            assert connection.dialect.supports_sane_multi_rowcount
 
-    def test_update(self):
-        with self.engine.connect() as conn:
-            conn.execute(
-                self.tables.data.insert(),
-                [
-                    {"x": "x1", "y": "y1"},
-                    {"x": "x2", "y": "y2"},
-                    {"x": "x3", "y": "y3"},
-                ],
-            )
+    def test_update(self, connection):
+        connection.execute(
+            self.tables.data.insert(),
+            [
+                {"x": "x1", "y": "y1"},
+                {"x": "x2", "y": "y2"},
+                {"x": "x3", "y": "y3"},
+            ],
+        )
 
-            conn.execute(
-                self.tables.data.update()
-                .where(self.tables.data.c.x == bindparam("xval"))
-                .values(y=bindparam("yval")),
-                [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}],
-            )
-            eq_(
-                conn.execute(
-                    select(self.tables.data).order_by(self.tables.data.c.id)
-                ).fetchall(),
-                [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)],
-            )
+        connection.execute(
+            self.tables.data.update()
+            .where(self.tables.data.c.x == bindparam("xval"))
+            .values(y=bindparam("yval")),
+            [{"xval": "x1", "yval": "y5"}, {"xval": "x3", "yval": "y6"}],
+        )
+        eq_(
+            connection.execute(
+                select(self.tables.data).order_by(self.tables.data.c.id)
+            ).fetchall(),
+            [(1, "x1", "y5", 5), (2, "x2", "y2", 5), (3, "x3", "y6", 5)],
+        )
 
 
 class ExecutemanyBatchModeTest(ExecuteManyMode, fixtures.TablesTest):
@@ -578,7 +578,7 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
             [(pk,) for pk in range(1 + first_pk, total_rows + first_pk)],
         )
 
-    def test_insert_w_newlines(self):
+    def test_insert_w_newlines(self, connection):
         from psycopg2 import extras
 
         t = self.tables.data
@@ -606,15 +606,14 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
             extras, "execute_values", side_effect=meth
         ) as mock_exec:
 
-            with self.engine.connect() as conn:
-                conn.execute(
-                    ins,
-                    [
-                        {"id": 1, "y": "y1", "z": 1},
-                        {"id": 2, "y": "y2", "z": 2},
-                        {"id": 3, "y": "y3", "z": 3},
-                    ],
-                )
+            connection.execute(
+                ins,
+                [
+                    {"id": 1, "y": "y1", "z": 1},
+                    {"id": 2, "y": "y2", "z": 2},
+                    {"id": 3, "y": "y3", "z": 3},
+                ],
+            )
 
         eq_(
             mock_exec.mock_calls,
@@ -629,12 +628,12 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
                     ),
                     template="(%(id)s, (SELECT 5 \nFROM data), %(y)s, %(z)s)",
                     fetch=False,
-                    page_size=conn.dialect.executemany_values_page_size,
+                    page_size=connection.dialect.executemany_values_page_size,
                 )
             ],
         )
 
-    def test_insert_modified_by_event(self):
+    def test_insert_modified_by_event(self, connection):
         from psycopg2 import extras
 
         t = self.tables.data
@@ -664,33 +663,33 @@ class ExecutemanyValuesInsertsTest(ExecuteManyMode, fixtures.TablesTest):
             extras, "execute_batch", side_effect=meth
         ) as mock_batch:
 
-            with self.engine.connect() as conn:
-
-                # create an event hook that will change the statement to
-                # something else, meaning the dialect has to detect that
-                # insert_single_values_expr is no longer useful
-                @event.listens_for(conn, "before_cursor_execute", retval=True)
-                def before_cursor_execute(
-                    conn, cursor, statement, parameters, context, executemany
-                ):
-                    statement = (
-                        "INSERT INTO data (id, y, z) VALUES "
-                        "(%(id)s, %(y)s, %(z)s)"
-                    )
-                    return statement, parameters
-
-                conn.execute(
-                    ins,
-                    [
-                        {"id": 1, "y": "y1", "z": 1},
-                        {"id": 2, "y": "y2", "z": 2},
-                        {"id": 3, "y": "y3", "z": 3},
-                    ],
+            # create an event hook that will change the statement to
+            # something else, meaning the dialect has to detect that
+            # insert_single_values_expr is no longer useful
+            @event.listens_for(
+                connection, "before_cursor_execute", retval=True
+            )
+            def before_cursor_execute(
+                conn, cursor, statement, parameters, context, executemany
+            ):
+                statement = (
+                    "INSERT INTO data (id, y, z) VALUES "
+                    "(%(id)s, %(y)s, %(z)s)"
                 )
+                return statement, parameters
+
+            connection.execute(
+                ins,
+                [
+                    {"id": 1, "y": "y1", "z": 1},
+                    {"id": 2, "y": "y2", "z": 2},
+                    {"id": 3, "y": "y3", "z": 3},
+                ],
+            )
 
         eq_(mock_values.mock_calls, [])
 
-        if self.engine.dialect.executemany_mode & EXECUTEMANY_BATCH:
+        if connection.dialect.executemany_mode & EXECUTEMANY_BATCH:
             eq_(
                 mock_batch.mock_calls,
                 [
@@ -727,10 +726,10 @@ class ExecutemanyFlagOptionsTest(fixtures.TablesTest):
             ("values_only", EXECUTEMANY_VALUES),
             ("values_plus_batch", EXECUTEMANY_VALUES_PLUS_BATCH),
         ]:
-            self.engine = engines.testing_engine(
+            connection = engines.testing_engine(
                 options={"executemany_mode": opt}
             )
-            is_(self.engine.dialect.executemany_mode, expected)
+            is_(connection.dialect.executemany_mode, expected)
 
     def test_executemany_wrong_flag_options(self):
         for opt in [1, True, "batch_insert"]:
@@ -1082,7 +1081,7 @@ $$ LANGUAGE plpgsql;
         t.create(connection, checkfirst=True)
 
     @testing.provide_metadata
-    def test_schema_roundtrips(self):
+    def test_schema_roundtrips(self, connection):
         meta = self.metadata
         users = Table(
             "users",
@@ -1091,33 +1090,37 @@ $$ LANGUAGE plpgsql;
             Column("name", String(50)),
             schema="test_schema",
         )
-        users.create()
-        users.insert().execute(id=1, name="name1")
-        users.insert().execute(id=2, name="name2")
-        users.insert().execute(id=3, name="name3")
-        users.insert().execute(id=4, name="name4")
+        users.create(connection)
+        connection.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=2, name="name2"))
+        connection.execute(users.insert(), dict(id=3, name="name3"))
+        connection.execute(users.insert(), dict(id=4, name="name4"))
         eq_(
-            users.select().where(users.c.name == "name2").execute().fetchall(),
+            connection.execute(
+                users.select().where(users.c.name == "name2")
+            ).fetchall(),
             [(2, "name2")],
         )
         eq_(
-            users.select(use_labels=True)
-            .where(users.c.name == "name2")
-            .execute()
-            .fetchall(),
+            connection.execute(
+                users.select().apply_labels().where(users.c.name == "name2")
+            ).fetchall(),
             [(2, "name2")],
         )
-        users.delete().where(users.c.id == 3).execute()
+        connection.execute(users.delete().where(users.c.id == 3))
         eq_(
-            users.select().where(users.c.name == "name3").execute().fetchall(),
+            connection.execute(
+                users.select().where(users.c.name == "name3")
+            ).fetchall(),
             [],
         )
-        users.update().where(users.c.name == "name4").execute(name="newname")
+        connection.execute(
+            users.update().where(users.c.name == "name4"), dict(name="newname")
+        )
         eq_(
-            users.select(use_labels=True)
-            .where(users.c.id == 4)
-            .execute()
-            .fetchall(),
+            connection.execute(
+                users.select().apply_labels().where(users.c.id == 4)
+            ).fetchall(),
             [(4, "newname")],
         )
 
@@ -1233,7 +1236,7 @@ $$ LANGUAGE plpgsql;
             ne_(conn.connection.status, STATUS_IN_TRANSACTION)
 
 
-class AutocommitTextTest(test_execute.AutocommitTextTest):
+class AutocommitTextTest(test_deprecations.AutocommitTextTest):
     __only_on__ = "postgresql"
 
     def test_grant(self):
index 76048784264004f1ac993e5c5977d6310a5c61cd..4e96cc6a217af1e67cb8b2b6b693d3c8d57dce55 100644 (file)
@@ -99,28 +99,29 @@ class OnConflictTest(fixtures.TablesTest):
             ValueError, insert(self.tables.users).on_conflict_do_update
         )
 
-    def test_on_conflict_do_nothing(self):
+    def test_on_conflict_do_nothing(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            result = conn.execute(
-                insert(users).on_conflict_do_nothing(),
-                dict(id=1, name="name1"),
-            )
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
-
-            result = conn.execute(
-                insert(users).on_conflict_do_nothing(),
-                dict(id=1, name="name2"),
-            )
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
-
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
+        result = connection.execute(
+            insert(users).on_conflict_do_nothing(),
+            dict(id=1, name="name1"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
+
+        result = connection.execute(
+            insert(users).on_conflict_do_nothing(),
+            dict(id=1, name="name2"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
+
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
 
     def test_on_conflict_do_nothing_connectionless(self, connection):
         users = self.tables.users_xtra
@@ -147,95 +148,99 @@ class OnConflictTest(fixtures.TablesTest):
         )
 
     @testing.provide_metadata
-    def test_on_conflict_do_nothing_target(self):
+    def test_on_conflict_do_nothing_target(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            result = conn.execute(
-                insert(users).on_conflict_do_nothing(
-                    index_elements=users.primary_key.columns
-                ),
-                dict(id=1, name="name1"),
-            )
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
-
-            result = conn.execute(
-                insert(users).on_conflict_do_nothing(
-                    index_elements=users.primary_key.columns
-                ),
-                dict(id=1, name="name2"),
-            )
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
-
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
-
-    def test_on_conflict_do_update_one(self):
+        result = connection.execute(
+            insert(users).on_conflict_do_nothing(
+                index_elements=users.primary_key.columns
+            ),
+            dict(id=1, name="name1"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
+
+        result = connection.execute(
+            insert(users).on_conflict_do_nothing(
+                index_elements=users.primary_key.columns
+            ),
+            dict(id=1, name="name2"),
+        )
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
+
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
+
+    def test_on_conflict_do_update_one(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.id], set_=dict(name=i.excluded.name)
-            )
-            result = conn.execute(i, dict(id=1, name="name1"))
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+        )
+        result = connection.execute(i, dict(id=1, name="name1"))
 
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
 
-    def test_on_conflict_do_update_schema(self):
+    def test_on_conflict_do_update_schema(self, connection):
         users = self.tables.get("%s.users_schema" % config.test_schema)
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.id], set_=dict(name=i.excluded.name)
-            )
-            result = conn.execute(i, dict(id=1, name="name1"))
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id], set_=dict(name=i.excluded.name)
+        )
+        result = connection.execute(i, dict(id=1, name="name1"))
 
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
 
-    def test_on_conflict_do_update_column_as_key_set(self):
+    def test_on_conflict_do_update_column_as_key_set(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.id],
-                set_={users.c.name: i.excluded.name},
-            )
-            result = conn.execute(i, dict(id=1, name="name1"))
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id],
+            set_={users.c.name: i.excluded.name},
+        )
+        result = connection.execute(i, dict(id=1, name="name1"))
 
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
 
-    def test_on_conflict_do_update_clauseelem_as_key_set(self):
+    def test_on_conflict_do_update_clauseelem_as_key_set(self, connection):
         users = self.tables.users
 
         class MyElem(object):
@@ -245,162 +250,165 @@ class OnConflictTest(fixtures.TablesTest):
             def __clause_element__(self):
                 return self.expr
 
-        with testing.db.connect() as conn:
-            conn.execute(
-                users.insert(),
-                {"id": 1, "name": "name1"},
-            )
+        connection.execute(
+            users.insert(),
+            {"id": 1, "name": "name1"},
+        )
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.id],
-                set_={MyElem(users.c.name): i.excluded.name},
-            ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"})
-            result = conn.execute(i)
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id],
+            set_={MyElem(users.c.name): i.excluded.name},
+        ).values({MyElem(users.c.id): 1, MyElem(users.c.name): "name1"})
+        result = connection.execute(i)
 
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
 
-    def test_on_conflict_do_update_column_as_key_set_schema(self):
+    def test_on_conflict_do_update_column_as_key_set_schema(self, connection):
         users = self.tables.get("%s.users_schema" % config.test_schema)
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.id],
-                set_={users.c.name: i.excluded.name},
-            )
-            result = conn.execute(i, dict(id=1, name="name1"))
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id],
+            set_={users.c.name: i.excluded.name},
+        )
+        result = connection.execute(i, dict(id=1, name="name1"))
 
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name1")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name1")],
+        )
 
-    def test_on_conflict_do_update_two(self):
+    def test_on_conflict_do_update_two(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.id],
-                set_=dict(id=i.excluded.id, name=i.excluded.name),
-            )
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.id],
+            set_=dict(id=i.excluded.id, name=i.excluded.name),
+        )
 
-            result = conn.execute(i, dict(id=1, name="name2"))
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        result = connection.execute(i, dict(id=1, name="name2"))
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name2")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name2")],
+        )
 
-    def test_on_conflict_do_update_three(self):
+    def test_on_conflict_do_update_three(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=users.primary_key.columns,
-                set_=dict(name=i.excluded.name),
-            )
-            result = conn.execute(i, dict(id=1, name="name3"))
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(name=i.excluded.name),
+        )
+        result = connection.execute(i, dict(id=1, name="name3"))
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name3")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name3")],
+        )
 
-    def test_on_conflict_do_update_four(self):
+    def test_on_conflict_do_update_four(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=users.primary_key.columns,
-                set_=dict(id=i.excluded.id, name=i.excluded.name),
-            ).values(id=1, name="name4")
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(id=i.excluded.id, name=i.excluded.name),
+        ).values(id=1, name="name4")
 
-            result = conn.execute(i)
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        result = connection.execute(i)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name4")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name4")],
+        )
 
-    def test_on_conflict_do_update_five(self):
+    def test_on_conflict_do_update_five(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=1, name="name1"))
 
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=users.primary_key.columns,
-                set_=dict(id=10, name="I'm a name"),
-            ).values(id=1, name="name4")
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(id=10, name="I'm a name"),
+        ).values(id=1, name="name4")
 
-            result = conn.execute(i)
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
+        result = connection.execute(i)
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
 
-            eq_(
-                conn.execute(
-                    users.select().where(users.c.id == 10)
-                ).fetchall(),
-                [(10, "I'm a name")],
-            )
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 10)
+            ).fetchall(),
+            [(10, "I'm a name")],
+        )
 
-    def test_on_conflict_do_update_multivalues(self):
+    def test_on_conflict_do_update_multivalues(self, connection):
         users = self.tables.users
 
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(id=1, name="name1"))
-            conn.execute(users.insert(), dict(id=2, name="name2"))
-
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=users.primary_key.columns,
-                set_=dict(name="updated"),
-                where=(i.excluded.name != "name12"),
-            ).values(
-                [
-                    dict(id=1, name="name11"),
-                    dict(id=2, name="name12"),
-                    dict(id=3, name="name13"),
-                    dict(id=4, name="name14"),
-                ]
-            )
-
-            result = conn.execute(i)
-            eq_(result.inserted_primary_key, (None,))
-            eq_(result.returned_defaults, None)
-
-            eq_(
-                conn.execute(users.select().order_by(users.c.id)).fetchall(),
-                [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")],
-            )
+        connection.execute(users.insert(), dict(id=1, name="name1"))
+        connection.execute(users.insert(), dict(id=2, name="name2"))
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(name="updated"),
+            where=(i.excluded.name != "name12"),
+        ).values(
+            [
+                dict(id=1, name="name11"),
+                dict(id=2, name="name12"),
+                dict(id=3, name="name13"),
+                dict(id=4, name="name14"),
+            ]
+        )
+
+        result = connection.execute(i)
+        eq_(result.inserted_primary_key, (None,))
+        eq_(result.returned_defaults, None)
+
+        eq_(
+            connection.execute(users.select().order_by(users.c.id)).fetchall(),
+            [(1, "updated"), (2, "name2"), (3, "name13"), (4, "name14")],
+        )
 
     def _exotic_targets_fixture(self, conn):
         users = self.tables.users_xtra
@@ -429,260 +437,250 @@ class OnConflictTest(fixtures.TablesTest):
             [(1, "name1", "name1@gmail.com", "not")],
         )
 
-    def test_on_conflict_do_update_exotic_targets_two(self):
+    def test_on_conflict_do_update_exotic_targets_two(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            self._exotic_targets_fixture(conn)
-            # try primary key constraint: cause an upsert on unique id column
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=users.primary_key.columns,
-                set_=dict(
-                    name=i.excluded.name, login_email=i.excluded.login_email
-                ),
-            )
-            result = conn.execute(
-                i,
-                dict(
-                    id=1,
-                    name="name2",
-                    login_email="name1@gmail.com",
-                    lets_index_this="not",
-                ),
-            )
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, None)
-
-            eq_(
-                conn.execute(users.select().where(users.c.id == 1)).fetchall(),
-                [(1, "name2", "name1@gmail.com", "not")],
-            )
-
-    def test_on_conflict_do_update_exotic_targets_three(self):
+        self._exotic_targets_fixture(connection)
+        # try primary key constraint: cause an upsert on unique id column
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=users.primary_key.columns,
+            set_=dict(
+                name=i.excluded.name, login_email=i.excluded.login_email
+            ),
+        )
+        result = connection.execute(
+            i,
+            dict(
+                id=1,
+                name="name2",
+                login_email="name1@gmail.com",
+                lets_index_this="not",
+            ),
+        )
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, None)
+
+        eq_(
+            connection.execute(
+                users.select().where(users.c.id == 1)
+            ).fetchall(),
+            [(1, "name2", "name1@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_three(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            self._exotic_targets_fixture(conn)
-            # try unique constraint: cause an upsert on target
-            # login_email, not id
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                constraint=self.unique_constraint,
-                set_=dict(
-                    id=i.excluded.id,
-                    name=i.excluded.name,
-                    login_email=i.excluded.login_email,
-                ),
-            )
-            # note: lets_index_this value totally ignored in SET clause.
-            result = conn.execute(
-                i,
-                dict(
-                    id=42,
-                    name="nameunique",
-                    login_email="name2@gmail.com",
-                    lets_index_this="unique",
-                ),
-            )
-            eq_(result.inserted_primary_key, (42,))
-            eq_(result.returned_defaults, None)
-
-            eq_(
-                conn.execute(
-                    users.select().where(
-                        users.c.login_email == "name2@gmail.com"
-                    )
-                ).fetchall(),
-                [(42, "nameunique", "name2@gmail.com", "not")],
-            )
-
-    def test_on_conflict_do_update_exotic_targets_four(self):
+        self._exotic_targets_fixture(connection)
+        # try unique constraint: cause an upsert on target
+        # login_email, not id
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            constraint=self.unique_constraint,
+            set_=dict(
+                id=i.excluded.id,
+                name=i.excluded.name,
+                login_email=i.excluded.login_email,
+            ),
+        )
+        # note: lets_index_this value totally ignored in SET clause.
+        result = connection.execute(
+            i,
+            dict(
+                id=42,
+                name="nameunique",
+                login_email="name2@gmail.com",
+                lets_index_this="unique",
+            ),
+        )
+        eq_(result.inserted_primary_key, (42,))
+        eq_(result.returned_defaults, None)
+
+        eq_(
+            connection.execute(
+                users.select().where(users.c.login_email == "name2@gmail.com")
+            ).fetchall(),
+            [(42, "nameunique", "name2@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_four(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            self._exotic_targets_fixture(conn)
-            # try unique constraint by name: cause an
-            # upsert on target login_email, not id
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                constraint=self.unique_constraint.name,
-                set_=dict(
-                    id=i.excluded.id,
-                    name=i.excluded.name,
-                    login_email=i.excluded.login_email,
-                ),
-            )
-            # note: lets_index_this value totally ignored in SET clause.
-
-            result = conn.execute(
-                i,
-                dict(
-                    id=43,
-                    name="nameunique2",
-                    login_email="name2@gmail.com",
-                    lets_index_this="unique",
-                ),
-            )
-            eq_(result.inserted_primary_key, (43,))
-            eq_(result.returned_defaults, None)
-
-            eq_(
-                conn.execute(
-                    users.select().where(
-                        users.c.login_email == "name2@gmail.com"
-                    )
-                ).fetchall(),
-                [(43, "nameunique2", "name2@gmail.com", "not")],
-            )
-
-    def test_on_conflict_do_update_exotic_targets_four_no_pk(self):
+        self._exotic_targets_fixture(connection)
+        # try unique constraint by name: cause an
+        # upsert on target login_email, not id
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            constraint=self.unique_constraint.name,
+            set_=dict(
+                id=i.excluded.id,
+                name=i.excluded.name,
+                login_email=i.excluded.login_email,
+            ),
+        )
+        # note: lets_index_this value totally ignored in SET clause.
+
+        result = connection.execute(
+            i,
+            dict(
+                id=43,
+                name="nameunique2",
+                login_email="name2@gmail.com",
+                lets_index_this="unique",
+            ),
+        )
+        eq_(result.inserted_primary_key, (43,))
+        eq_(result.returned_defaults, None)
+
+        eq_(
+            connection.execute(
+                users.select().where(users.c.login_email == "name2@gmail.com")
+            ).fetchall(),
+            [(43, "nameunique2", "name2@gmail.com", "not")],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_four_no_pk(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            self._exotic_targets_fixture(conn)
-            # try unique constraint by name: cause an
-            # upsert on target login_email, not id
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.login_email],
-                set_=dict(
-                    id=i.excluded.id,
-                    name=i.excluded.name,
-                    login_email=i.excluded.login_email,
-                ),
-            )
-
-            result = conn.execute(
-                i, dict(name="name3", login_email="name1@gmail.com")
-            )
-            eq_(result.inserted_primary_key, (1,))
-            eq_(result.returned_defaults, (1,))
-
-            eq_(
-                conn.execute(users.select().order_by(users.c.id)).fetchall(),
-                [
-                    (1, "name3", "name1@gmail.com", "not"),
-                    (2, "name2", "name2@gmail.com", "not"),
-                ],
-            )
-
-    def test_on_conflict_do_update_exotic_targets_five(self):
+        self._exotic_targets_fixture(connection)
+        # try unique constraint by name: cause an
+        # upsert on target login_email, not id
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.login_email],
+            set_=dict(
+                id=i.excluded.id,
+                name=i.excluded.name,
+                login_email=i.excluded.login_email,
+            ),
+        )
+
+        result = connection.execute(
+            i, dict(name="name3", login_email="name1@gmail.com")
+        )
+        eq_(result.inserted_primary_key, (1,))
+        eq_(result.returned_defaults, (1,))
+
+        eq_(
+            connection.execute(users.select().order_by(users.c.id)).fetchall(),
+            [
+                (1, "name3", "name1@gmail.com", "not"),
+                (2, "name2", "name2@gmail.com", "not"),
+            ],
+        )
+
+    def test_on_conflict_do_update_exotic_targets_five(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            self._exotic_targets_fixture(conn)
-            # try bogus index
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=self.bogus_index.columns,
-                index_where=self.bogus_index.dialect_options["postgresql"][
-                    "where"
-                ],
-                set_=dict(
-                    name=i.excluded.name, login_email=i.excluded.login_email
-                ),
-            )
-
-            assert_raises(
-                exc.ProgrammingError,
-                conn.execute,
-                i,
-                dict(
-                    id=1,
-                    name="namebogus",
-                    login_email="bogus@gmail.com",
-                    lets_index_this="bogus",
-                ),
-            )
-
-    def test_on_conflict_do_update_exotic_targets_six(self):
+        self._exotic_targets_fixture(connection)
+        # try bogus index
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=self.bogus_index.columns,
+            index_where=self.bogus_index.dialect_options["postgresql"][
+                "where"
+            ],
+            set_=dict(
+                name=i.excluded.name, login_email=i.excluded.login_email
+            ),
+        )
+
+        assert_raises(
+            exc.ProgrammingError,
+            connection.execute,
+            i,
+            dict(
+                id=1,
+                name="namebogus",
+                login_email="bogus@gmail.com",
+                lets_index_this="bogus",
+            ),
+        )
+
+    def test_on_conflict_do_update_exotic_targets_six(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            conn.execute(
-                insert(users),
+        connection.execute(
+            insert(users),
+            dict(
+                id=1,
+                name="name1",
+                login_email="mail1@gmail.com",
+                lets_index_this="unique_name",
+            ),
+        )
+
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=self.unique_partial_index.columns,
+            index_where=self.unique_partial_index.dialect_options[
+                "postgresql"
+            ]["where"],
+            set_=dict(
+                name=i.excluded.name, login_email=i.excluded.login_email
+            ),
+        )
+
+        connection.execute(
+            i,
+            [
                 dict(
-                    id=1,
                     name="name1",
-                    login_email="mail1@gmail.com",
+                    login_email="mail2@gmail.com",
                     lets_index_this="unique_name",
-                ),
-            )
-
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=self.unique_partial_index.columns,
-                index_where=self.unique_partial_index.dialect_options[
-                    "postgresql"
-                ]["where"],
-                set_=dict(
-                    name=i.excluded.name, login_email=i.excluded.login_email
-                ),
-            )
-
-            conn.execute(
-                i,
-                [
-                    dict(
-                        name="name1",
-                        login_email="mail2@gmail.com",
-                        lets_index_this="unique_name",
-                    )
-                ],
-            )
-
-            eq_(
-                conn.execute(users.select()).fetchall(),
-                [(1, "name1", "mail2@gmail.com", "unique_name")],
-            )
-
-    def test_on_conflict_do_update_no_row_actually_affected(self):
+                )
+            ],
+        )
+
+        eq_(
+            connection.execute(users.select()).fetchall(),
+            [(1, "name1", "mail2@gmail.com", "unique_name")],
+        )
+
+    def test_on_conflict_do_update_no_row_actually_affected(self, connection):
         users = self.tables.users_xtra
 
-        with testing.db.connect() as conn:
-            self._exotic_targets_fixture(conn)
-            i = insert(users)
-            i = i.on_conflict_do_update(
-                index_elements=[users.c.login_email],
-                set_=dict(name="new_name"),
-                where=(i.excluded.name == "other_name"),
-            )
-            result = conn.execute(
-                i, dict(name="name2", login_email="name1@gmail.com")
-            )
-
-            eq_(result.returned_defaults, None)
-            eq_(result.inserted_primary_key, None)
-
-            eq_(
-                conn.execute(users.select()).fetchall(),
-                [
-                    (1, "name1", "name1@gmail.com", "not"),
-                    (2, "name2", "name2@gmail.com", "not"),
-                ],
-            )
-
-    def test_on_conflict_do_update_special_types_in_set(self):
+        self._exotic_targets_fixture(connection)
+        i = insert(users)
+        i = i.on_conflict_do_update(
+            index_elements=[users.c.login_email],
+            set_=dict(name="new_name"),
+            where=(i.excluded.name == "other_name"),
+        )
+        result = connection.execute(
+            i, dict(name="name2", login_email="name1@gmail.com")
+        )
+
+        eq_(result.returned_defaults, None)
+        eq_(result.inserted_primary_key, None)
+
+        eq_(
+            connection.execute(users.select()).fetchall(),
+            [
+                (1, "name1", "name1@gmail.com", "not"),
+                (2, "name2", "name2@gmail.com", "not"),
+            ],
+        )
+
+    def test_on_conflict_do_update_special_types_in_set(self, connection):
         bind_targets = self.tables.bind_targets
 
-        with testing.db.connect() as conn:
-            i = insert(bind_targets)
-            conn.execute(i, {"id": 1, "data": "initial data"})
-
-            eq_(
-                conn.scalar(sql.select(bind_targets.c.data)),
-                "initial data processed",
-            )
-
-            i = insert(bind_targets)
-            i = i.on_conflict_do_update(
-                index_elements=[bind_targets.c.id],
-                set_=dict(data="new updated data"),
-            )
-            conn.execute(i, {"id": 1, "data": "new inserted data"})
-
-            eq_(
-                conn.scalar(sql.select(bind_targets.c.data)),
-                "new updated data processed",
-            )
+        i = insert(bind_targets)
+        connection.execute(i, {"id": 1, "data": "initial data"})
+
+        eq_(
+            connection.scalar(sql.select(bind_targets.c.data)),
+            "initial data processed",
+        )
+
+        i = insert(bind_targets)
+        i = i.on_conflict_do_update(
+            index_elements=[bind_targets.c.id],
+            set_=dict(data="new updated data"),
+        )
+        connection.execute(i, {"id": 1, "data": "new inserted data"})
+
+        eq_(
+            connection.scalar(sql.select(bind_targets.c.data)),
+            "new updated data processed",
+        )
index c959acf359b9a5e76dc5feff9162a8cd65b967c4..94af168eee0e925dfc9b85a1ee80f02b97615c89 100644 (file)
@@ -35,30 +35,32 @@ from sqlalchemy.testing.assertsql import CursorSQL
 from sqlalchemy.testing.assertsql import DialectSQL
 
 
-matchtable = cattable = None
-
-
 class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
     __only_on__ = "postgresql"
     __backend__ = True
 
-    @classmethod
-    def setup_class(cls):
-        cls.metadata = MetaData(testing.db)
+    def setup(self):
+        self.metadata = MetaData()
 
     def teardown(self):
-        self.metadata.drop_all()
-        self.metadata.clear()
+        with testing.db.begin() as conn:
+            self.metadata.drop_all(conn)
+
+    @testing.combinations((False,), (True,))
+    def test_foreignkey_missing_insert(self, implicit_returning):
+        engine = engines.testing_engine(
+            options={"implicit_returning": implicit_returning}
+        )
 
-    def test_foreignkey_missing_insert(self):
         Table("t1", self.metadata, Column("id", Integer, primary_key=True))
         t2 = Table(
             "t2",
             self.metadata,
             Column("id", Integer, ForeignKey("t1.id"), primary_key=True),
         )
-        self.metadata.create_all()
+
+        self.metadata.create_all(engine)
 
         # want to ensure that "null value in column "id" violates not-
         # null constraint" is raised (IntegrityError on psycoopg2, but
@@ -67,19 +69,13 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         # the latter corresponds to autoincrement behavior, which is not
         # the case here due to the foreign key.
 
-        for eng in [
-            engines.testing_engine(options={"implicit_returning": False}),
-            engines.testing_engine(options={"implicit_returning": True}),
-        ]:
-            with expect_warnings(
-                ".*has no Python-side or server-side default.*"
-            ):
-                with eng.connect() as conn:
-                    assert_raises(
-                        (exc.IntegrityError, exc.ProgrammingError),
-                        conn.execute,
-                        t2.insert(),
-                    )
+        with expect_warnings(".*has no Python-side or server-side default.*"):
+            with engine.begin() as conn:
+                assert_raises(
+                    (exc.IntegrityError, exc.ProgrammingError),
+                    conn.execute,
+                    t2.insert(),
+                )
 
     def test_sequence_insert(self):
         table = Table(
@@ -88,7 +84,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             Column("id", Integer, Sequence("my_seq"), primary_key=True),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_with_sequence(table, "my_seq")
 
     @testing.requires.returning
@@ -99,7 +95,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             Column("id", Integer, Sequence("my_seq"), primary_key=True),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_with_sequence_returning(table, "my_seq")
 
     def test_opt_sequence_insert(self):
@@ -114,7 +110,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_autoincrement(table)
 
     @testing.requires.returning
@@ -130,7 +126,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_autoincrement_returning(table)
 
     def test_autoincrement_insert(self):
@@ -140,7 +136,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             Column("id", Integer, primary_key=True),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_autoincrement(table)
 
     @testing.requires.returning
@@ -151,7 +147,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             Column("id", Integer, primary_key=True),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_autoincrement_returning(table)
 
     def test_noautoincrement_insert(self):
@@ -161,7 +157,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             Column("id", Integer, primary_key=True, autoincrement=False),
             Column("data", String(30)),
         )
-        self.metadata.create_all()
+        self.metadata.create_all(testing.db)
         self._assert_data_noautoincrement(table)
 
     def _assert_data_autoincrement(self, table):
@@ -169,7 +165,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
         with self.sql_execution_asserter(engine) as asserter:
 
-            with engine.connect() as conn:
+            with engine.begin() as conn:
                 # execute with explicit id
 
                 r = conn.execute(table.insert(), {"id": 30, "data": "d1"})
@@ -226,7 +222,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             ),
         )
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             eq_(
                 conn.execute(table.select()).fetchall(),
                 [
@@ -250,7 +246,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         table = Table(table.name, m2, autoload_with=engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            with engine.connect() as conn:
+            with engine.begin() as conn:
                 conn.execute(table.insert(), {"id": 30, "data": "d1"})
                 r = conn.execute(table.insert(), {"data": "d2"})
                 eq_(r.inserted_primary_key, (5,))
@@ -288,7 +284,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                 "INSERT INTO testtable (data) VALUES (:data)", [{"data": "d8"}]
             ),
         )
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             eq_(
                 conn.execute(table.select()).fetchall(),
                 [
@@ -308,7 +304,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         engine = engines.testing_engine(options={"implicit_returning": True})
 
         with self.sql_execution_asserter(engine) as asserter:
-            with engine.connect() as conn:
+            with engine.begin() as conn:
 
                 # execute with explicit id
 
@@ -367,7 +363,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             ),
         )
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             eq_(
                 conn.execute(table.select()).fetchall(),
                 [
@@ -390,7 +386,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         table = Table(table.name, m2, autoload_with=engine)
 
         with self.sql_execution_asserter(engine) as asserter:
-            with engine.connect() as conn:
+            with engine.begin() as conn:
                 conn.execute(table.insert(), {"id": 30, "data": "d1"})
                 r = conn.execute(table.insert(), {"data": "d2"})
                 eq_(r.inserted_primary_key, (5,))
@@ -430,7 +426,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             ),
         )
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             eq_(
                 conn.execute(table.select()).fetchall(),
                 [
@@ -450,7 +446,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         engine = engines.testing_engine(options={"implicit_returning": False})
 
         with self.sql_execution_asserter(engine) as asserter:
-            with engine.connect() as conn:
+            with engine.begin() as conn:
                 conn.execute(table.insert(), {"id": 30, "data": "d1"})
                 conn.execute(table.insert(), {"data": "d2"})
                 conn.execute(
@@ -491,7 +487,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                 [{"data": "d8"}],
             ),
         )
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             eq_(
                 conn.execute(table.select()).fetchall(),
                 [
@@ -513,7 +509,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
         engine = engines.testing_engine(options={"implicit_returning": True})
 
         with self.sql_execution_asserter(engine) as asserter:
-            with engine.connect() as conn:
+            with engine.begin() as conn:
                 conn.execute(table.insert(), {"id": 30, "data": "d1"})
                 conn.execute(table.insert(), {"data": "d2"})
                 conn.execute(
@@ -555,7 +551,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             ),
         )
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             eq_(
                 conn.execute(table.select()).fetchall(),
                 [
@@ -578,9 +574,12 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
         # turning off the cache because we are checking for compile-time
         # warnings
-        with engine.connect().execution_options(compiled_cache=None) as conn:
+        engine = engine.execution_options(compiled_cache=None)
+
+        with engine.begin() as conn:
             conn.execute(table.insert(), {"id": 30, "data": "d1"})
 
+        with engine.begin() as conn:
             with expect_warnings(
                 ".*has no Python-side or server-side default.*"
             ):
@@ -590,6 +589,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                     table.insert(),
                     {"data": "d2"},
                 )
+
+        with engine.begin() as conn:
             with expect_warnings(
                 ".*has no Python-side or server-side default.*"
             ):
@@ -599,6 +600,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                     table.insert(),
                     [{"data": "d2"}, {"data": "d3"}],
                 )
+
+        with engine.begin() as conn:
             with expect_warnings(
                 ".*has no Python-side or server-side default.*"
             ):
@@ -608,6 +611,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                     table.insert(),
                     {"data": "d2"},
                 )
+
+        with engine.begin() as conn:
             with expect_warnings(
                 ".*has no Python-side or server-side default.*"
             ):
@@ -618,6 +623,7 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                     [{"data": "d2"}, {"data": "d3"}],
                 )
 
+        with engine.begin() as conn:
             conn.execute(
                 table.insert(),
                 [{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}],
@@ -634,9 +640,10 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
 
         m2 = MetaData()
         table = Table(table.name, m2, autoload_with=engine)
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             conn.execute(table.insert(), {"id": 30, "data": "d1"})
 
+        with engine.begin() as conn:
             with expect_warnings(
                 ".*has no Python-side or server-side default.*"
             ):
@@ -646,6 +653,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                     table.insert(),
                     {"data": "d2"},
                 )
+
+        with engine.begin() as conn:
             with expect_warnings(
                 ".*has no Python-side or server-side default.*"
             ):
@@ -655,6 +664,8 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
                     table.insert(),
                     [{"data": "d2"}, {"data": "d3"}],
                 )
+
+        with engine.begin() as conn:
             conn.execute(
                 table.insert(),
                 [{"id": 31, "data": "d2"}, {"id": 32, "data": "d3"}],
@@ -666,36 +677,40 @@ class InsertTest(fixtures.TestBase, AssertsExecutionResults):
             )
 
 
-class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
+class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
 
     __only_on__ = "postgresql >= 8.3"
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global metadata, cattable, matchtable
-        metadata = MetaData(testing.db)
-        cattable = Table(
+    def define_tables(cls, metadata):
+        Table(
             "cattable",
             metadata,
             Column("id", Integer, primary_key=True),
             Column("description", String(50)),
         )
-        matchtable = Table(
+        Table(
             "matchtable",
             metadata,
             Column("id", Integer, primary_key=True),
             Column("title", String(200)),
             Column("category_id", Integer, ForeignKey("cattable.id")),
         )
-        metadata.create_all()
-        cattable.insert().execute(
+
+    @classmethod
+    def insert_data(cls, connection):
+        cattable, matchtable = cls.tables("cattable", "matchtable")
+
+        connection.execute(
+            cattable.insert(),
             [
                 {"id": 1, "description": "Python"},
                 {"id": 2, "description": "Ruby"},
-            ]
+            ],
         )
-        matchtable.insert().execute(
+        connection.execute(
+            matchtable.insert(),
             [
                 {
                     "id": 1,
@@ -714,15 +729,12 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
                     "category_id": 1,
                 },
                 {"id": 5, "title": "Python in a Nutshell", "category_id": 1},
-            ]
+            ],
         )
 
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
-
     @testing.requires.pyformat_paramstyle
     def test_expression_pyformat(self):
+        matchtable = self.tables.matchtable
         self.assert_compile(
             matchtable.c.title.match("somstr"),
             "matchtable.title @@ to_tsquery(%(title_1)s" ")",
@@ -730,51 +742,47 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
 
     @testing.requires.format_paramstyle
     def test_expression_positional(self):
+        matchtable = self.tables.matchtable
         self.assert_compile(
             matchtable.c.title.match("somstr"),
             "matchtable.title @@ to_tsquery(%s)",
         )
 
-    def test_simple_match(self):
-        results = (
+    def test_simple_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
             matchtable.select()
             .where(matchtable.c.title.match("python"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([2, 5], [r.id for r in results])
 
-    def test_not_match(self):
-        results = (
+    def test_not_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
             matchtable.select()
             .where(~matchtable.c.title.match("python"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([1, 3, 4], [r.id for r in results])
 
-    def test_simple_match_with_apostrophe(self):
-        results = (
-            matchtable.select()
-            .where(matchtable.c.title.match("Matz's"))
-            .execute()
-            .fetchall()
-        )
+    def test_simple_match_with_apostrophe(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
+            matchtable.select().where(matchtable.c.title.match("Matz's"))
+        ).fetchall()
         eq_([3], [r.id for r in results])
 
-    def test_simple_derivative_match(self):
-        results = (
-            matchtable.select()
-            .where(matchtable.c.title.match("nutshells"))
-            .execute()
-            .fetchall()
-        )
+    def test_simple_derivative_match(self, connection):
+        matchtable = self.tables.matchtable
+        results = connection.execute(
+            matchtable.select().where(matchtable.c.title.match("nutshells"))
+        ).fetchall()
         eq_([5], [r.id for r in results])
 
-    def test_or_match(self):
-        results1 = (
+    def test_or_match(self, connection):
+        matchtable = self.tables.matchtable
+        results1 = connection.execute(
             matchtable.select()
             .where(
                 or_(
@@ -783,42 +791,36 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
                 )
             )
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([3, 5], [r.id for r in results1])
-        results2 = (
+        results2 = connection.execute(
             matchtable.select()
             .where(matchtable.c.title.match("nutshells | rubies"))
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([3, 5], [r.id for r in results2])
 
-    def test_and_match(self):
-        results1 = (
-            matchtable.select()
-            .where(
+    def test_and_match(self, connection):
+        matchtable = self.tables.matchtable
+        results1 = connection.execute(
+            matchtable.select().where(
                 and_(
                     matchtable.c.title.match("python"),
                     matchtable.c.title.match("nutshells"),
                 )
             )
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([5], [r.id for r in results1])
-        results2 = (
-            matchtable.select()
-            .where(matchtable.c.title.match("python & nutshells"))
-            .execute()
-            .fetchall()
-        )
+        results2 = connection.execute(
+            matchtable.select().where(
+                matchtable.c.title.match("python & nutshells")
+            )
+        ).fetchall()
         eq_([5], [r.id for r in results2])
 
-    def test_match_across_joins(self):
-        results = (
+    def test_match_across_joins(self, connection):
+        cattable, matchtable = self.tables("cattable", "matchtable")
+        results = connection.execute(
             matchtable.select()
             .where(
                 and_(
@@ -830,9 +832,7 @@ class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
                 )
             )
             .order_by(matchtable.c.id)
-            .execute()
-            .fetchall()
-        )
+        ).fetchall()
         eq_([1, 3, 5], [r.id for r in results])
 
 
index 4de4d88e3140648f399b6284d6f4810fb753f5f3..824f6cd36dcab61948d7f2fa6a604599f9db1f3c 100644 (file)
@@ -291,63 +291,64 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
 
     @classmethod
     def setup_class(cls):
-        con = testing.db.connect()
-        for ddl in [
-            'CREATE SCHEMA "SomeSchema"',
-            "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42",
-            "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0",
-            "CREATE TYPE testtype AS ENUM ('test')",
-            "CREATE DOMAIN enumdomain AS testtype",
-            "CREATE DOMAIN arraydomain AS INTEGER[]",
-            'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0',
-        ]:
-            try:
-                con.exec_driver_sql(ddl)
-            except exc.DBAPIError as e:
-                if "already exists" not in str(e):
-                    raise e
-        con.exec_driver_sql(
-            "CREATE TABLE testtable (question integer, answer " "testdomain)"
-        )
-        con.exec_driver_sql(
-            "CREATE TABLE test_schema.testtable(question "
-            "integer, answer test_schema.testdomain, anything "
-            "integer)"
-        )
-        con.exec_driver_sql(
-            "CREATE TABLE crosschema (question integer, answer "
-            "test_schema.testdomain)"
-        )
+        with testing.db.begin() as con:
+            for ddl in [
+                'CREATE SCHEMA "SomeSchema"',
+                "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42",
+                "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0",
+                "CREATE TYPE testtype AS ENUM ('test')",
+                "CREATE DOMAIN enumdomain AS testtype",
+                "CREATE DOMAIN arraydomain AS INTEGER[]",
+                'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0',
+            ]:
+                try:
+                    con.exec_driver_sql(ddl)
+                except exc.DBAPIError as e:
+                    if "already exists" not in str(e):
+                        raise e
+            con.exec_driver_sql(
+                "CREATE TABLE testtable (question integer, answer "
+                "testdomain)"
+            )
+            con.exec_driver_sql(
+                "CREATE TABLE test_schema.testtable(question "
+                "integer, answer test_schema.testdomain, anything "
+                "integer)"
+            )
+            con.exec_driver_sql(
+                "CREATE TABLE crosschema (question integer, answer "
+                "test_schema.testdomain)"
+            )
 
-        con.exec_driver_sql(
-            "CREATE TABLE enum_test (id integer, data enumdomain)"
-        )
+            con.exec_driver_sql(
+                "CREATE TABLE enum_test (id integer, data enumdomain)"
+            )
 
-        con.exec_driver_sql(
-            "CREATE TABLE array_test (id integer, data arraydomain)"
-        )
+            con.exec_driver_sql(
+                "CREATE TABLE array_test (id integer, data arraydomain)"
+            )
 
-        con.exec_driver_sql(
-            "CREATE TABLE quote_test "
-            '(id integer, data "SomeSchema"."Quoted.Domain")'
-        )
+            con.exec_driver_sql(
+                "CREATE TABLE quote_test "
+                '(id integer, data "SomeSchema"."Quoted.Domain")'
+            )
 
     @classmethod
     def teardown_class(cls):
-        con = testing.db.connect()
-        con.exec_driver_sql("DROP TABLE testtable")
-        con.exec_driver_sql("DROP TABLE test_schema.testtable")
-        con.exec_driver_sql("DROP TABLE crosschema")
-        con.exec_driver_sql("DROP TABLE quote_test")
-        con.exec_driver_sql("DROP DOMAIN testdomain")
-        con.exec_driver_sql("DROP DOMAIN test_schema.testdomain")
-        con.exec_driver_sql("DROP TABLE enum_test")
-        con.exec_driver_sql("DROP DOMAIN enumdomain")
-        con.exec_driver_sql("DROP TYPE testtype")
-        con.exec_driver_sql("DROP TABLE array_test")
-        con.exec_driver_sql("DROP DOMAIN arraydomain")
-        con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
-        con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
+        with testing.db.begin() as con:
+            con.exec_driver_sql("DROP TABLE testtable")
+            con.exec_driver_sql("DROP TABLE test_schema.testtable")
+            con.exec_driver_sql("DROP TABLE crosschema")
+            con.exec_driver_sql("DROP TABLE quote_test")
+            con.exec_driver_sql("DROP DOMAIN testdomain")
+            con.exec_driver_sql("DROP DOMAIN test_schema.testdomain")
+            con.exec_driver_sql("DROP TABLE enum_test")
+            con.exec_driver_sql("DROP DOMAIN enumdomain")
+            con.exec_driver_sql("DROP TYPE testtype")
+            con.exec_driver_sql("DROP TABLE array_test")
+            con.exec_driver_sql("DROP DOMAIN arraydomain")
+            con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
+            con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
 
     def test_table_is_reflected(self):
         metadata = MetaData()
@@ -486,7 +487,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("ref", Integer, ForeignKey("subject.id$")),
         )
-        meta1.create_all()
+        meta1.create_all(testing.db)
         meta2 = MetaData()
         subject = Table("subject", meta2, autoload_with=testing.db)
         referer = Table("referer", meta2, autoload_with=testing.db)
@@ -523,9 +524,11 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
         with testing.db.begin() as conn:
             r = conn.execute(t2.insert())
             eq_(r.inserted_primary_key, (1,))
-        testing.db.connect().execution_options(
-            autocommit=True
-        ).exec_driver_sql("alter table t_id_seq rename to foobar_id_seq")
+
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql(
+                "alter table t_id_seq rename to foobar_id_seq"
+            )
         m3 = MetaData()
         t3 = Table("t", m3, autoload_with=testing.db, implicit_returning=False)
         eq_(
@@ -545,10 +548,12 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("x", Integer),
         )
-        metadata.create_all()
-        testing.db.connect().execution_options(
-            autocommit=True
-        ).exec_driver_sql("alter table t alter column id type varchar(50)")
+        metadata.create_all(testing.db)
+
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql(
+                "alter table t alter column id type varchar(50)"
+            )
         m2 = MetaData()
         t2 = Table("t", m2, autoload_with=testing.db)
         eq_(t2.c.id.autoincrement, False)
@@ -558,10 +563,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
     def test_renamed_pk_reflection(self):
         metadata = self.metadata
         Table("t", metadata, Column("id", Integer, primary_key=True))
-        metadata.create_all()
-        testing.db.connect().execution_options(
-            autocommit=True
-        ).exec_driver_sql("alter table t rename id to t_id")
+        metadata.create_all(testing.db)
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("alter table t rename id to t_id")
         m2 = MetaData()
         t2 = Table("t", m2, autoload_with=testing.db)
         eq_([c.name for c in t2.primary_key], ["t_id"])
@@ -936,13 +940,13 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("name", String(20), index=True),
             Column("aname", String(20)),
         )
-        metadata.create_all()
-        with testing.db.connect() as c:
-            c.exec_driver_sql("create index idx1 on party ((id || name))")
-            c.exec_driver_sql(
+        metadata.create_all(testing.db)
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("create index idx1 on party ((id || name))")
+            conn.exec_driver_sql(
                 "create unique index idx2 on party (id) where name = 'test'"
             )
-            c.exec_driver_sql(
+            conn.exec_driver_sql(
                 """
                 create index idx3 on party using btree
                     (lower(name::text), lower(aname::text))
@@ -1029,7 +1033,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("aname", String(20)),
         )
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
 
             t1.create(conn)
 
@@ -1109,18 +1113,19 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("x", Integer),
         )
-        metadata.create_all()
-        conn = testing.db.connect().execution_options(autocommit=True)
-        conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
-        conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
+        metadata.create_all(testing.db)
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("CREATE INDEX idx1 ON t (x)")
+            conn.exec_driver_sql("ALTER TABLE t RENAME COLUMN x to y")
 
-        ind = testing.db.dialect.get_indexes(conn, "t", None)
-        expected = [{"name": "idx1", "unique": False, "column_names": ["y"]}]
-        if testing.requires.index_reflects_included_columns.enabled:
-            expected[0]["include_columns"] = []
+            ind = testing.db.dialect.get_indexes(conn, "t", None)
+            expected = [
+                {"name": "idx1", "unique": False, "column_names": ["y"]}
+            ]
+            if testing.requires.index_reflects_included_columns.enabled:
+                expected[0]["include_columns"] = []
 
-        eq_(ind, expected)
-        conn.close()
+            eq_(ind, expected)
 
     @testing.fails_if("postgresql < 8.2", "reloptions not supported")
     @testing.provide_metadata
@@ -1135,9 +1140,9 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("x", Integer),
         )
-        metadata.create_all()
+        metadata.create_all(testing.db)
 
-        with testing.db.connect().execution_options(autocommit=True) as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql(
                 "CREATE INDEX idx1 ON t (x) WITH (fillfactor = 50)"
             )
@@ -1177,8 +1182,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("id", Integer, primary_key=True),
             Column("x", ARRAY(Integer)),
         )
-        metadata.create_all()
-        with testing.db.connect().execution_options(autocommit=True) as conn:
+        metadata.create_all(testing.db)
+        with testing.db.begin() as conn:
             conn.exec_driver_sql("CREATE INDEX idx1 ON t USING gin (x)")
 
             ind = testing.db.dialect.get_indexes(conn, "t", None)
@@ -1215,7 +1220,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             Column("name", String(20)),
         )
         metadata.create_all()
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql("CREATE INDEX idx1 ON t (x) INCLUDE (name)")
 
             # prior to #5205, this would return:
@@ -1312,8 +1317,7 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             eq_(fk, fk_ref[fk["name"]])
 
     @testing.provide_metadata
-    def test_inspect_enums_schema(self):
-        conn = testing.db.connect()
+    def test_inspect_enums_schema(self, connection):
         enum_type = postgresql.ENUM(
             "sad",
             "ok",
@@ -1322,8 +1326,8 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             schema="test_schema",
             metadata=self.metadata,
         )
-        enum_type.create(conn)
-        inspector = inspect(conn)
+        enum_type.create(connection)
+        inspector = inspect(connection)
         eq_(
             inspector.get_enums("test_schema"),
             [
index e7174f234a0a83d6884977523188e5a68745ffb0..ae7a65a3af8a03b8065a50af4a85b92c325ec063 100644 (file)
@@ -206,7 +206,7 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             schema=symbol_name,
         )
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn = conn.execution_options(
                 schema_translate_map={symbol_name: testing.config.test_schema}
             )
index de8b22b67cd8350231e546b8f66d79661958950d..cd8768d73b5474e13e1f9f0234968767e37bdd84 100644 (file)
@@ -30,34 +30,37 @@ class MxODBCTest(fixtures.TestBase):
         )
         conn = engine.connect()
 
-        # crud: uses execute
-        conn.execute(t1.insert().values(c1="foo"))
-        conn.execute(t1.delete().where(t1.c.c1 == "foo"))
-        conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar"))
-
-        # select: uses executedirect
-        conn.execute(t1.select())
-
-        # manual flagging
-        conn.execution_options(native_odbc_execute=True).execute(t1.select())
-        conn.execution_options(native_odbc_execute=False).execute(
-            t1.insert().values(c1="foo")
-        )
+        with conn.begin():
+            # crud: uses execute
+            conn.execute(t1.insert().values(c1="foo"))
+            conn.execute(t1.delete().where(t1.c.c1 == "foo"))
+            conn.execute(t1.update().where(t1.c.c1 == "foo").values(c1="bar"))
 
-        eq_(
-            # fmt: off
-            [
-                c[2]
-                for c in dbapi.connect.return_value.cursor.
-                return_value.execute.mock_calls
-            ],
-            # fmt: on
-            [
-                {"direct": True},
-                {"direct": True},
-                {"direct": True},
-                {"direct": True},
-                {"direct": False},
-                {"direct": True},
-            ]
-        )
+            # select: uses executedirect
+            conn.execute(t1.select())
+
+            # manual flagging
+            conn.execution_options(native_odbc_execute=True).execute(
+                t1.select()
+            )
+            conn.execution_options(native_odbc_execute=False).execute(
+                t1.insert().values(c1="foo")
+            )
+
+            eq_(
+                # fmt: off
+                [
+                    c[2]
+                    for c in dbapi.connect.return_value.cursor.
+                    return_value.execute.mock_calls
+                ],
+                # fmt: on
+                [
+                    {"direct": True},
+                    {"direct": True},
+                    {"direct": True},
+                    {"direct": True},
+                    {"direct": False},
+                    {"direct": True},
+                ]
+            )
index f8b50f8883979e8f32088372028ba97e8b6aaa53..12200f832f376b4a145855d8f4924db4cf5a323c 100644 (file)
@@ -63,8 +63,9 @@ from sqlalchemy.util import ue
 
 
 def exec_sql(engine, sql, *args, **kwargs):
-    conn = engine.connect(close_with_result=True)
-    return conn.exec_driver_sql(sql, *args, **kwargs)
+    # TODO: convert all tests to not use this
+    with engine.begin() as conn:
+        conn.exec_driver_sql(sql, *args, **kwargs)
 
 
 class TestTypes(fixtures.TestBase, AssertsExecutionResults):
@@ -189,11 +190,13 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
         connection.execute(
             t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0))
         )
-        exec_sql(
-            testing.db, "insert into t (d) values ('2004-05-21T00:00:00')"
+        connection.exec_driver_sql(
+            "insert into t (d) values ('2004-05-21T00:00:00')"
         )
         eq_(
-            exec_sql(testing.db, "select * from t order by d").fetchall(),
+            connection.exec_driver_sql(
+                "select * from t order by d"
+            ).fetchall(),
             [("2004-05-21T00:00:00",), ("2010-10-15T12:37:00",)],
         )
         eq_(
@@ -216,9 +219,13 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
         connection.execute(
             t.insert().values(d=datetime.datetime(2010, 10, 15, 12, 37, 0))
         )
-        exec_sql(testing.db, "insert into t (d) values ('20040521000000')")
+        connection.exec_driver_sql(
+            "insert into t (d) values ('20040521000000')"
+        )
         eq_(
-            exec_sql(testing.db, "select * from t order by d").fetchall(),
+            connection.exec_driver_sql(
+                "select * from t order by d"
+            ).fetchall(),
             [("20040521000000",), ("20101015123700",)],
         )
         eq_(
@@ -238,9 +245,11 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
         t = Table("t", self.metadata, Column("d", sqlite_date))
         self.metadata.create_all(connection)
         connection.execute(t.insert().values(d=datetime.date(2010, 10, 15)))
-        exec_sql(testing.db, "insert into t (d) values ('20040521')")
+        connection.exec_driver_sql("insert into t (d) values ('20040521')")
         eq_(
-            exec_sql(testing.db, "select * from t order by d").fetchall(),
+            connection.exec_driver_sql(
+                "select * from t order by d"
+            ).fetchall(),
             [("20040521",), ("20101015",)],
         )
         eq_(
@@ -256,11 +265,15 @@ class TestTypes(fixtures.TestBase, AssertsExecutionResults):
             regexp=r"(\d+)\|(\d+)\|(\d+)",
         )
         t = Table("t", self.metadata, Column("d", sqlite_date))
-        self.metadata.create_all(testing.db)
+        self.metadata.create_all(connection)
         connection.execute(t.insert().values(d=datetime.date(2010, 10, 15)))
-        exec_sql(testing.db, "insert into t (d) values ('2004|05|21')")
+
+        connection.exec_driver_sql("insert into t (d) values ('2004|05|21')")
+
         eq_(
-            exec_sql(testing.db, "select * from t order by d").fetchall(),
+            connection.exec_driver_sql(
+                "select * from t order by d"
+            ).fetchall(),
             [("2004|05|21",), ("2010|10|15",)],
         )
         eq_(
@@ -313,7 +326,7 @@ class JSONTest(fixtures.TestBase):
 
         value = {"json": {"foo": "bar"}, "recs": ["one", "two"]}
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(sqlite_json.insert(), foo=value)
 
             eq_(conn.scalar(select(sqlite_json.c.foo)), value)
@@ -328,7 +341,7 @@ class JSONTest(fixtures.TestBase):
 
         value = {"json": {"foo": "bar"}}
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(sqlite_json.insert(), foo=value)
 
             eq_(conn.scalar(select(sqlite_json.c.foo["json"])), value["json"])
@@ -551,7 +564,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
             Column("x", Boolean, server_default=sql.false()),
         )
         t.create(testing.db)
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(t.insert())
             conn.execute(t.insert().values(x=True))
             eq_(
@@ -568,7 +581,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
             Column("x", DateTime(), server_default=func.now()),
         )
         t.create(testing.db)
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             now = conn.scalar(func.now())
             today = datetime.datetime.today()
             conn.execute(t.insert())
@@ -587,7 +600,7 @@ class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
             Column("x", Integer(), server_default=func.abs(-5) + 17),
         )
         t.create(testing.db)
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(t.insert())
             conn.execute(t.insert().values(x=35))
             eq_(
@@ -622,7 +635,8 @@ class DialectTest(
                 )
             )
 
-    def test_extra_reserved_words(self):
+    @testing.provide_metadata
+    def test_extra_reserved_words(self, connection):
         """Tests reserved words in identifiers.
 
         'true', 'false', and 'column' are undocumented reserved words
@@ -630,22 +644,19 @@ class DialectTest(
         here to ensure they remain in place if the dialect's
         reserved_words set is updated in the future."""
 
-        meta = MetaData(testing.db)
         t = Table(
             "reserved",
-            meta,
+            self.metadata,
             Column("safe", Integer),
             Column("true", Integer),
             Column("false", Integer),
             Column("column", Integer),
             Column("exists", Integer),
         )
-        try:
-            meta.create_all()
-            t.insert().execute(safe=1)
-            list(t.select().execute())
-        finally:
-            meta.drop_all()
+        self.metadata.create_all(connection)
+        connection.execute(t.insert(), dict(safe=1))
+        result = connection.execute(t.select())
+        eq_(list(result), [(1, None, None, None, None)])
 
     @testing.provide_metadata
     def test_quoted_identifiers_functional_one(self):
@@ -827,7 +838,8 @@ class AttachedDBTest(fixtures.TestBase):
             schema="test_schema",
         )
 
-        meta.create_all(self.conn)
+        with self.conn.begin():
+            meta.create_all(self.conn)
         return ct
 
     def setup(self):
@@ -835,7 +847,8 @@ class AttachedDBTest(fixtures.TestBase):
         self.metadata = MetaData()
 
     def teardown(self):
-        self.metadata.drop_all(self.conn)
+        with self.conn.begin():
+            self.metadata.drop_all(self.conn)
         self.conn.close()
 
     def test_no_tables(self):
@@ -928,18 +941,20 @@ class AttachedDBTest(fixtures.TestBase):
     def test_crud(self):
         ct = self._fixture()
 
-        self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
-        eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")])
+        with self.conn.begin():
+            self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
+            eq_(self.conn.execute(ct.select()).fetchall(), [(1, "foo")])
 
-        self.conn.execute(ct.update(), {"id": 2, "name": "bar"})
-        eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")])
-        self.conn.execute(ct.delete())
-        eq_(self.conn.execute(ct.select()).fetchall(), [])
+            self.conn.execute(ct.update(), {"id": 2, "name": "bar"})
+            eq_(self.conn.execute(ct.select()).fetchall(), [(2, "bar")])
+            self.conn.execute(ct.delete())
+            eq_(self.conn.execute(ct.select()).fetchall(), [])
 
     def test_col_targeting(self):
         ct = self._fixture()
 
-        self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
+        with self.conn.begin():
+            self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
         row = self.conn.execute(ct.select()).first()
         eq_(row._mapping["id"], 1)
         eq_(row._mapping["name"], "foo")
@@ -947,7 +962,8 @@ class AttachedDBTest(fixtures.TestBase):
     def test_col_targeting_union(self):
         ct = self._fixture()
 
-        self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
+        with self.conn.begin():
+            self.conn.execute(ct.insert(), {"id": 1, "name": "foo"})
         row = self.conn.execute(ct.select().union(ct.select())).first()
         eq_(row._mapping["id"], 1)
         eq_(row._mapping["name"], "foo")
@@ -2236,7 +2252,7 @@ class ConstraintReflectionTest(fixtures.TestBase):
         )
 
     def test_foreign_key_options_unnamed_inline(self):
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql(
                 "create table foo (id integer, "
                 "foreign key (id) references bar (id) on update cascade)"
@@ -2571,33 +2587,33 @@ class TypeReflectionTest(fixtures.TestBase):
     def _test_round_trip(self, fixture, warnings=False):
         from sqlalchemy import inspect
 
-        conn = testing.db.connect()
         for from_, to_ in self._fixture_as_string(fixture):
-            inspector = inspect(conn)
-            conn.exec_driver_sql("CREATE TABLE foo (data %s)" % from_)
-            try:
-                if warnings:
+            with testing.db.begin() as conn:
+                inspector = inspect(conn)
+                conn.exec_driver_sql("CREATE TABLE foo (data %s)" % from_)
+                try:
+                    if warnings:
 
-                    def go():
-                        return inspector.get_columns("foo")[0]
+                        def go():
+                            return inspector.get_columns("foo")[0]
 
-                    col_info = testing.assert_warnings(
-                        go, ["Could not instantiate"], regex=True
-                    )
-                else:
-                    col_info = inspector.get_columns("foo")[0]
-                expected_type = type(to_)
-                is_(type(col_info["type"]), expected_type)
-
-                # test args
-                for attr in ("scale", "precision", "length"):
-                    if getattr(to_, attr, None) is not None:
-                        eq_(
-                            getattr(col_info["type"], attr),
-                            getattr(to_, attr, None),
+                        col_info = testing.assert_warnings(
+                            go, ["Could not instantiate"], regex=True
                         )
-            finally:
-                conn.exec_driver_sql("DROP TABLE foo")
+                    else:
+                        col_info = inspector.get_columns("foo")[0]
+                    expected_type = type(to_)
+                    is_(type(col_info["type"]), expected_type)
+
+                    # test args
+                    for attr in ("scale", "precision", "length"):
+                        if getattr(to_, attr, None) is not None:
+                            eq_(
+                                getattr(col_info["type"], attr),
+                                getattr(to_, attr, None),
+                            )
+                finally:
+                    conn.exec_driver_sql("DROP TABLE foo")
 
     def test_lookup_direct_lookup(self):
         self._test_lookup_direct(self._fixed_lookup_fixture())
index f2429175f96191844a3c178796f21c619bf70f7e..5cbb4785466da784858738acd630520a2b392f33 100644 (file)
@@ -489,6 +489,7 @@ class DDLExecutionTest(fixtures.TestBase):
     def test_ddl_execute(self):
         engine = create_engine("sqlite:///")
         cx = engine.connect()
+        cx.begin()
         table = self.users
         ddl = DDL("SELECT 1")
 
index 5e32cc3e96f888c46b9f05a527debd144401df44..47e59b55da3267df15fa377256e32dff4c396643 100644 (file)
@@ -93,6 +93,9 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
 
         for meta in (MetaData, ThreadLocalMetaData):
             for bind in (testing.db, testing.db.connect()):
+                if isinstance(bind, engine.Connection):
+                    bind.begin()
+
                 if meta is ThreadLocalMetaData:
                     with testing.expect_deprecated(
                         "ThreadLocalMetaData is deprecated"
@@ -151,6 +154,8 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
 
     def test_bind_create_drop_constructor_bound(self):
         for bind in (testing.db, testing.db.connect()):
+            if isinstance(bind, engine.Connection):
+                bind.begin()
             try:
                 for args in (([bind], {}), ([], {"bind": bind})):
                     metadata = MetaData(*args[0], **args[1])
@@ -177,15 +182,25 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
             test_needs_acid=True,
         )
         conn = testing.db.connect()
-        metadata.create_all(bind=conn)
+        with conn.begin():
+            metadata.create_all(bind=conn)
         try:
             trans = conn.begin()
             metadata.bind = conn
             t = table.insert()
             assert t.bind is conn
-            table.insert().execute(foo=5)
-            table.insert().execute(foo=6)
-            table.insert().execute(foo=7)
+            with testing.expect_deprecated_20(
+                r"The Executable.execute\(\) method is considered legacy"
+            ):
+                table.insert().execute(foo=5)
+            with testing.expect_deprecated_20(
+                r"The Executable.execute\(\) method is considered legacy"
+            ):
+                table.insert().execute(foo=6)
+            with testing.expect_deprecated_20(
+                r"The Executable.execute\(\) method is considered legacy"
+            ):
+                table.insert().execute(foo=7)
             trans.rollback()
             metadata.bind = None
             assert (
@@ -195,7 +210,8 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
                 == 0
             )
         finally:
-            metadata.drop_all(bind=conn)
+            with conn.begin():
+                metadata.drop_all(bind=conn)
 
     def test_bind_clauseelement(self):
         metadata = MetaData()
@@ -215,14 +231,21 @@ class ConnectionlessDeprecationTest(fixtures.TestBase):
                         ):
                             e = elem(bind=bind)
                         assert e.bind is bind
-                        e.execute().close()
+                        with testing.expect_deprecated_20(
+                            r"The Executable.execute\(\) method is "
+                            "considered legacy"
+                        ):
+                            e.execute().close()
                     finally:
                         if isinstance(bind, engine.Connection):
                             bind.close()
 
                 e = elem()
                 assert e.bind is None
-                assert_raises(exc.UnboundExecutionError, e.execute)
+                with testing.expect_deprecated_20(
+                    r"The Executable.execute\(\) method is considered legacy"
+                ):
+                    assert_raises(exc.UnboundExecutionError, e.execute)
         finally:
             if isinstance(bind, engine.Connection):
                 bind.close()
@@ -365,6 +388,11 @@ class TransactionTest(fixtures.TablesTest):
         )
         Table("inserttable", metadata, Column("data", String(20)))
 
+    @testing.fixture
+    def local_connection(self):
+        with testing.db.connect() as conn:
+            yield conn
+
     def test_transaction_container(self):
         users = self.tables.users
 
@@ -429,6 +457,110 @@ class TransactionTest(fixtures.TablesTest):
                     "insert into inserttable (data) values ('thedata')"
                 )
 
+    def test_branch_autorollback(self, local_connection):
+        connection = local_connection
+        users = self.tables.users
+        branched = connection.connect()
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            branched.execute(
+                users.insert(), dict(user_id=1, user_name="user1")
+            )
+        assert_raises(
+            exc.DBAPIError,
+            branched.execute,
+            users.insert(),
+            dict(user_id=1, user_name="user1"),
+        )
+        # can continue w/o issue
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            branched.execute(
+                users.insert(), dict(user_id=2, user_name="user2")
+            )
+
+    def test_branch_orig_rollback(self, local_connection):
+        connection = local_connection
+        users = self.tables.users
+        branched = connection.connect()
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            branched.execute(
+                users.insert(), dict(user_id=1, user_name="user1")
+            )
+        nested = branched.begin()
+        assert branched.in_transaction()
+        branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
+        nested.rollback()
+        eq_(
+            connection.exec_driver_sql("select count(*) from users").scalar(),
+            1,
+        )
+
+    @testing.requires.independent_connections
+    def test_branch_autocommit(self, local_connection):
+        users = self.tables.users
+        with testing.db.connect() as connection:
+            branched = connection.connect()
+            with testing.expect_deprecated_20(
+                "The current statement is being autocommitted using "
+                "implicit autocommit"
+            ):
+                branched.execute(
+                    users.insert(), dict(user_id=1, user_name="user1")
+                )
+
+        eq_(
+            local_connection.execute(
+                text("select count(*) from users")
+            ).scalar(),
+            1,
+        )
+
+    @testing.requires.savepoints
+    def test_branch_savepoint_rollback(self, local_connection):
+        connection = local_connection
+        users = self.tables.users
+        trans = connection.begin()
+        branched = connection.connect()
+        assert branched.in_transaction()
+        branched.execute(users.insert(), user_id=1, user_name="user1")
+        nested = branched.begin_nested()
+        branched.execute(users.insert(), user_id=2, user_name="user2")
+        nested.rollback()
+        assert connection.in_transaction()
+        trans.commit()
+        eq_(
+            connection.exec_driver_sql("select count(*) from users").scalar(),
+            1,
+        )
+
+    @testing.requires.two_phase_transactions
+    def test_branch_twophase_rollback(self, local_connection):
+        connection = local_connection
+        users = self.tables.users
+        branched = connection.connect()
+        assert not branched.in_transaction()
+        with testing.expect_deprecated_20(
+            r"The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            branched.execute(users.insert(), user_id=1, user_name="user1")
+        nested = branched.begin_twophase()
+        branched.execute(users.insert(), user_id=2, user_name="user2")
+        nested.rollback()
+        assert not connection.in_transaction()
+        eq_(
+            connection.exec_driver_sql("select count(*) from users").scalar(),
+            1,
+        )
+
 
 class HandleInvalidatedOnConnectTest(fixtures.TestBase):
     __requires__ = ("sqlite",)
@@ -699,20 +831,20 @@ class DeprecatedReflectionTest(fixtures.TablesTest):
     def test_create_drop_explicit(self):
         metadata = MetaData()
         table = Table("test_table", metadata, Column("foo", Integer))
-        for bind in (testing.db, testing.db.connect()):
-            for args in [([], {"bind": bind}), ([bind], {})]:
-                metadata.create_all(*args[0], **args[1])
-                with testing.expect_deprecated(
-                    r"The Table.exists\(\) method is deprecated"
-                ):
-                    assert table.exists(*args[0], **args[1])
-                metadata.drop_all(*args[0], **args[1])
-                table.create(*args[0], **args[1])
-                table.drop(*args[0], **args[1])
-                with testing.expect_deprecated(
-                    r"The Table.exists\(\) method is deprecated"
-                ):
-                    assert not table.exists(*args[0], **args[1])
+        bind = testing.db
+        for args in [([], {"bind": bind}), ([bind], {})]:
+            metadata.create_all(*args[0], **args[1])
+            with testing.expect_deprecated(
+                r"The Table.exists\(\) method is deprecated"
+            ):
+                assert table.exists(*args[0], **args[1])
+            metadata.drop_all(*args[0], **args[1])
+            table.create(*args[0], **args[1])
+            table.drop(*args[0], **args[1])
+            with testing.expect_deprecated(
+                r"The Table.exists\(\) method is deprecated"
+            ):
+                assert not table.exists(*args[0], **args[1])
 
     def test_create_drop_err_table(self):
         metadata = MetaData()
@@ -1195,3 +1327,208 @@ class DDLExecutionTest(fixtures.TestBase):
                 with testing.expect_deprecated_20(ddl_msg):
                     r = fn(**kw)
                 eq_(list(r), [(1,)])
+
+
+class AutocommitKeywordFixture(object):
+    def _test_keyword(self, keyword, expected=True):
+        dbapi = Mock(
+            connect=Mock(
+                return_value=Mock(
+                    cursor=Mock(return_value=Mock(description=()))
+                )
+            )
+        )
+        engine = engines.testing_engine(
+            options={"_initialize": False, "pool_reset_on_return": None}
+        )
+        engine.dialect.dbapi = dbapi
+
+        with engine.connect() as conn:
+            if expected:
+                with testing.expect_deprecated_20(
+                    "The current statement is being autocommitted "
+                    "using implicit autocommit"
+                ):
+                    conn.exec_driver_sql(
+                        "%s something table something" % keyword
+                    )
+            else:
+                conn.exec_driver_sql("%s something table something" % keyword)
+
+            if expected:
+                eq_(
+                    [n for (n, k, s) in dbapi.connect().mock_calls],
+                    ["cursor", "commit"],
+                )
+            else:
+                eq_(
+                    [n for (n, k, s) in dbapi.connect().mock_calls], ["cursor"]
+                )
+
+
+class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase):
+    __backend__ = True
+
+    def test_update(self):
+        self._test_keyword("UPDATE")
+
+    def test_insert(self):
+        self._test_keyword("INSERT")
+
+    def test_delete(self):
+        self._test_keyword("DELETE")
+
+    def test_alter(self):
+        self._test_keyword("ALTER TABLE")
+
+    def test_create(self):
+        self._test_keyword("CREATE TABLE foobar")
+
+    def test_drop(self):
+        self._test_keyword("DROP TABLE foobar")
+
+    def test_select(self):
+        self._test_keyword("SELECT foo FROM table", False)
+
+
+class ExplicitAutoCommitTest(fixtures.TestBase):
+
+    """test the 'autocommit' flag on select() and text() objects.
+
+    Requires PostgreSQL so that we may define a custom function which
+    modifies the database."""
+
+    __only_on__ = "postgresql"
+
+    @classmethod
+    def setup_class(cls):
+        global metadata, foo
+        metadata = MetaData(testing.db)
+        foo = Table(
+            "foo",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(100)),
+        )
+        with testing.db.begin() as conn:
+            metadata.create_all(conn)
+            conn.exec_driver_sql(
+                "create function insert_foo(varchar) "
+                "returns integer as 'insert into foo(data) "
+                "values ($1);select 1;' language sql"
+            )
+
+    def teardown(self):
+        with testing.db.begin() as conn:
+            conn.execute(foo.delete())
+
+    @classmethod
+    def teardown_class(cls):
+        with testing.db.begin() as conn:
+            conn.exec_driver_sql("drop function insert_foo(varchar)")
+            metadata.drop_all(conn)
+
+    def test_control(self):
+
+        # test that not using autocommit does not commit
+
+        conn1 = testing.db.connect()
+        conn2 = testing.db.connect()
+        conn1.execute(select(func.insert_foo("data1")))
+        assert conn2.execute(select(foo.c.data)).fetchall() == []
+        conn1.execute(text("select insert_foo('moredata')"))
+        assert conn2.execute(select(foo.c.data)).fetchall() == []
+        trans = conn1.begin()
+        trans.commit()
+        assert conn2.execute(select(foo.c.data)).fetchall() == [
+            ("data1",),
+            ("moredata",),
+        ]
+        conn1.close()
+        conn2.close()
+
+    def test_explicit_compiled(self):
+        conn1 = testing.db.connect()
+        conn2 = testing.db.connect()
+
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            conn1.execute(
+                select(func.insert_foo("data1")).execution_options(
+                    autocommit=True
+                )
+            )
+        assert conn2.execute(select(foo.c.data)).fetchall() == [("data1",)]
+        conn1.close()
+        conn2.close()
+
+    def test_explicit_connection(self):
+        conn1 = testing.db.connect()
+        conn2 = testing.db.connect()
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            conn1.execution_options(autocommit=True).execute(
+                select(func.insert_foo("data1"))
+            )
+        eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
+
+        # connection supersedes statement
+
+        conn1.execution_options(autocommit=False).execute(
+            select(func.insert_foo("data2")).execution_options(autocommit=True)
+        )
+        eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
+
+        # ditto
+
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            conn1.execution_options(autocommit=True).execute(
+                select(func.insert_foo("data3")).execution_options(
+                    autocommit=False
+                )
+            )
+        eq_(
+            conn2.execute(select(foo.c.data)).fetchall(),
+            [("data1",), ("data2",), ("data3",)],
+        )
+        conn1.close()
+        conn2.close()
+
+    def test_explicit_text(self):
+        conn1 = testing.db.connect()
+        conn2 = testing.db.connect()
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            conn1.execute(
+                text("select insert_foo('moredata')").execution_options(
+                    autocommit=True
+                )
+            )
+        assert conn2.execute(select(foo.c.data)).fetchall() == [("moredata",)]
+        conn1.close()
+        conn2.close()
+
+    def test_implicit_text(self):
+        conn1 = testing.db.connect()
+        conn2 = testing.db.connect()
+        with testing.expect_deprecated_20(
+            "The current statement is being autocommitted using "
+            "implicit autocommit"
+        ):
+            conn1.execute(
+                text("insert into foo (data) values ('implicitdata')")
+            )
+        assert conn2.execute(select(foo.c.data)).fetchall() == [
+            ("implicitdata",)
+        ]
+        conn1.close()
+        conn2.close()
index efec9376c1057a1495fd71e5225fef7c857ee5ed..55a114409bd146c6bc1b837364f3ad356bfcbe9a 100644 (file)
@@ -543,13 +543,15 @@ class ExecuteTest(fixtures.TablesTest):
 
     @testing.only_on("sqlite")
     def test_execute_compiled_favors_compiled_paramstyle(self):
+        users = self.tables.users
+
         with patch.object(testing.db.dialect, "do_execute") as do_exec:
             stmt = users.update().values(user_id=1, user_name="foo")
 
             d1 = default.DefaultDialect(paramstyle="format")
             d2 = default.DefaultDialect(paramstyle="pyformat")
 
-            with testing.db.connect() as conn:
+            with testing.db.begin() as conn:
                 conn.execute(stmt.compile(dialect=d1))
                 conn.execute(stmt.compile(dialect=d2))
 
@@ -805,9 +807,8 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
 
     def test_connection_as_ctx(self):
         fn = self._trans_fn()
-        ctx = testing.db.connect()
-        testing.run_as_contextmanager(ctx, fn, 5, value=8)
-        # autocommit is on
+        with testing.db.begin() as conn:
+            fn(conn, 5, value=8)
         self._assert_fn(5, value=8)
 
     @testing.fails_on("mysql+oursql", "oursql bug ?  getting wrong rowcount")
@@ -822,14 +823,12 @@ class ConvenienceExecuteTest(fixtures.TablesTest):
             self._assert_no_data()
 
 
-class CompiledCacheTest(fixtures.TestBase):
+class CompiledCacheTest(fixtures.TablesTest):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global users, metadata
-        metadata = MetaData(testing.db)
-        users = Table(
+    def define_tables(cls, metadata):
+        Table(
             "users",
             metadata,
             Column(
@@ -838,19 +837,11 @@ class CompiledCacheTest(fixtures.TestBase):
             Column("user_name", VARCHAR(20)),
             Column("extra_data", VARCHAR(20)),
         )
-        metadata.create_all()
 
-    @engines.close_first
-    def teardown(self):
-        with testing.db.connect() as conn:
-            conn.execute(users.delete())
-
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
+    def test_cache(self, connection):
+        users = self.tables.users
 
-    def test_cache(self):
-        conn = testing.db.connect()
+        conn = connection
         cache = {}
         cached_conn = conn.execution_options(compiled_cache=cache)
 
@@ -870,7 +861,7 @@ class CompiledCacheTest(fixtures.TestBase):
         "uses blob value that is problematic for some DBAPIs",
     )
     @testing.provide_metadata
-    def test_cache_noleak_on_statement_values(self):
+    def test_cache_noleak_on_statement_values(self, connection):
         # This is a non regression test for an object reference leak caused
         # by the compiled_cache.
 
@@ -883,11 +874,10 @@ class CompiledCacheTest(fixtures.TestBase):
             ),
             Column("photo_blob", LargeBinary()),
         )
-        metadata.create_all()
+        metadata.create_all(connection)
 
-        conn = testing.db.connect()
         cache = {}
-        cached_conn = conn.execution_options(compiled_cache=cache)
+        cached_conn = connection.execution_options(compiled_cache=cache)
 
         class PhotoBlob(bytearray):
             pass
@@ -902,7 +892,10 @@ class CompiledCacheTest(fixtures.TestBase):
             cached_conn.execute(ins, {"photo_blob": blob})
         eq_(compile_mock.call_count, 1)
         eq_(len(cache), 1)
-        eq_(conn.exec_driver_sql("select count(*) from photo").scalar(), 1)
+        eq_(
+            connection.exec_driver_sql("select count(*) from photo").scalar(),
+            1,
+        )
 
         del blob
 
@@ -912,14 +905,15 @@ class CompiledCacheTest(fixtures.TestBase):
         # the statement values (only the keys).
         eq_(ref_blob(), None)
 
-    def test_keys_independent_of_ordering(self):
-        conn = testing.db.connect()
-        conn.execute(
+    def test_keys_independent_of_ordering(self, connection):
+        users = self.tables.users
+
+        connection.execute(
             users.insert(),
             {"user_id": 1, "user_name": "u1", "extra_data": "e1"},
         )
         cache = {}
-        cached_conn = conn.execution_options(compiled_cache=cache)
+        cached_conn = connection.execution_options(compiled_cache=cache)
 
         upd = users.update().where(users.c.user_id == bindparam("b_user_id"))
 
@@ -974,30 +968,32 @@ class CompiledCacheTest(fixtures.TestBase):
         stmt = select(t1.c.q)
 
         cache = {}
-        with config.db.connect().execution_options(
-            compiled_cache=cache
-        ) as conn:
+        with config.db.begin() as conn:
+            conn = conn.execution_options(compiled_cache=cache)
             conn.execute(ins, {"q": 1})
             eq_(conn.scalar(stmt), 1)
 
-        with config.db.connect().execution_options(
-            compiled_cache=cache,
-            schema_translate_map={None: config.test_schema},
-        ) as conn:
+        with config.db.begin() as conn:
+            conn = conn.execution_options(
+                compiled_cache=cache,
+                schema_translate_map={None: config.test_schema},
+            )
             conn.execute(ins, {"q": 2})
             eq_(conn.scalar(stmt), 2)
 
-        with config.db.connect().execution_options(
-            compiled_cache=cache,
-            schema_translate_map={None: None},
-        ) as conn:
+        with config.db.begin() as conn:
+            conn = conn.execution_options(
+                compiled_cache=cache,
+                schema_translate_map={None: None},
+            )
             # should use default schema again even though statement
             # was compiled with test_schema in the map
             eq_(conn.scalar(stmt), 1)
 
-        with config.db.connect().execution_options(
-            compiled_cache=cache
-        ) as conn:
+        with config.db.begin() as conn:
+            conn = conn.execution_options(
+                compiled_cache=cache,
+            )
             eq_(conn.scalar(stmt), 1)
 
 
@@ -1050,7 +1046,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
 
         with self.sql_execution_asserter(config.db) as asserter:
-            with config.db.connect().execution_options(
+            with config.db.begin() as conn, conn.execution_options(
                 schema_translate_map=map_
             ) as conn:
 
@@ -1091,9 +1087,8 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         Table("t2", metadata, Column("x", Integer), schema="foo")
         Table("t3", metadata, Column("x", Integer), schema="bar")
 
-        with config.db.connect().execution_options(
-            schema_translate_map=map_
-        ) as conn:
+        with config.db.begin() as conn:
+            conn = conn.execution_options(schema_translate_map=map_)
             metadata.create_all(conn)
 
         insp = inspect(config.db)
@@ -1101,9 +1096,8 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         is_true(insp.has_table("t2", schema=config.test_schema))
         is_true(insp.has_table("t3", schema=None))
 
-        with config.db.connect().execution_options(
-            schema_translate_map=map_
-        ) as conn:
+        with config.db.begin() as conn:
+            conn = conn.execution_options(schema_translate_map=map_)
             metadata.drop_all(conn)
 
         insp = inspect(config.db)
@@ -1127,7 +1121,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
 
         with self.sql_execution_asserter(config.db) as asserter:
-            with config.db.connect() as conn:
+            with config.db.begin() as conn:
 
                 execution_options = {"schema_translate_map": map_}
                 conn._execute_20(
@@ -1222,7 +1216,7 @@ class SchemaTranslateTest(fixtures.TestBase, testing.AssertsExecutionResults):
         t3 = Table("t3", metadata, Column("x", Integer), schema="bar")
 
         with self.sql_execution_asserter(config.db) as asserter:
-            with config.db.connect().execution_options(
+            with config.db.begin() as conn, conn.execution_options(
                 schema_translate_map=map_
             ) as conn:
 
@@ -1790,6 +1784,7 @@ class EngineEventsTest(fixtures.TestBase):
             else:
                 ctx = conn = engine.connect()
 
+            trans = conn.begin()
             try:
                 m.create_all(conn, checkfirst=False)
                 try:
@@ -1801,8 +1796,7 @@ class EngineEventsTest(fixtures.TestBase):
                     )
                 finally:
                     m.drop_all(conn)
-                    if engine._is_future:
-                        conn.commit()
+                    trans.commit()
             finally:
                 if ctx:
                     ctx.close()
@@ -3046,7 +3040,7 @@ class DialectEventTest(fixtures.TestBase):
             m1.do_execute_no_params.side_effect
         ) = mock_the_cursor
 
-        with e.connect() as conn:
+        with e.begin() as conn:
             yield conn, m1
 
     def _assert(self, retval, m1, m2, mock_calls):
@@ -3244,59 +3238,6 @@ class DialectEventTest(fixtures.TestBase):
         eq_(conn.info["boom"], "one")
 
 
-class AutocommitKeywordFixture(object):
-    def _test_keyword(self, keyword, expected=True):
-        dbapi = Mock(
-            connect=Mock(
-                return_value=Mock(
-                    cursor=Mock(return_value=Mock(description=()))
-                )
-            )
-        )
-        engine = engines.testing_engine(
-            options={"_initialize": False, "pool_reset_on_return": None}
-        )
-        engine.dialect.dbapi = dbapi
-
-        with engine.connect() as conn:
-            conn.exec_driver_sql("%s something table something" % keyword)
-
-            if expected:
-                eq_(
-                    [n for (n, k, s) in dbapi.connect().mock_calls],
-                    ["cursor", "commit"],
-                )
-            else:
-                eq_(
-                    [n for (n, k, s) in dbapi.connect().mock_calls], ["cursor"]
-                )
-
-
-class AutocommitTextTest(AutocommitKeywordFixture, fixtures.TestBase):
-    __backend__ = True
-
-    def test_update(self):
-        self._test_keyword("UPDATE")
-
-    def test_insert(self):
-        self._test_keyword("INSERT")
-
-    def test_delete(self):
-        self._test_keyword("DELETE")
-
-    def test_alter(self):
-        self._test_keyword("ALTER TABLE")
-
-    def test_create(self):
-        self._test_keyword("CREATE TABLE foobar")
-
-    def test_drop(self):
-        self._test_keyword("DROP TABLE foobar")
-
-    def test_select(self):
-        self._test_keyword("SELECT foo FROM table", False)
-
-
 class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
     __backend__ = True
 
@@ -3463,7 +3404,7 @@ class SetInputSizesTest(fixtures.TablesTest):
     def test_set_input_sizes_no_event(self, input_sizes_fixture):
         engine, canary = input_sizes_fixture
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             conn.execute(
                 self.tables.users.insert(),
                 [
@@ -3596,7 +3537,7 @@ class SetInputSizesTest(fixtures.TablesTest):
                         0,
                     )
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             conn.execute(
                 self.tables.users.insert(),
                 [
index aa272c0cf545085da898b43e8627a238c1b6d824..29b8132aa326bbcb4a133a807e88f731f27613d4 100644 (file)
@@ -22,7 +22,7 @@ from sqlalchemy.testing.util import lazy_gc
 
 
 def exec_sql(engine, sql, *args, **kwargs):
-    with engine.connect() as conn:
+    with engine.begin() as conn:
         return conn.exec_driver_sql(sql, *args, **kwargs)
 
 
@@ -56,7 +56,7 @@ class LogParamsTest(fixtures.TestBase):
             [{"data": str(i)} for i in range(100)],
         )
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] [{'data': '0'}, {'data': '1'}, {'data': '2'}, "
             "{'data': '3'}, "
             "{'data': '4'}, {'data': '5'}, {'data': '6'}, {'data': '7'}"
@@ -86,7 +86,7 @@ class LogParamsTest(fixtures.TestBase):
             [{"data": str(i)} for i in range(100)],
         )
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] [SQL parameters hidden due to hide_parameters=True]",
         )
 
@@ -97,7 +97,7 @@ class LogParamsTest(fixtures.TestBase):
             [(str(i),) for i in range(100)],
         )
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] [('0',), ('1',), ('2',), ('3',), ('4',), ('5',), "
             "('6',), ('7',)  ... displaying 10 of 100 total "
             "bound parameter sets ...  ('98',), ('99',)]",
@@ -227,7 +227,7 @@ class LogParamsTest(fixtures.TestBase):
         exec_sql(self.eng, "INSERT INTO foo (data) values (?)", (largeparam,))
 
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] ('%s ... (4702 characters truncated) ... %s',)"
             % (largeparam[0:149], largeparam[-149:]),
         )
@@ -242,7 +242,7 @@ class LogParamsTest(fixtures.TestBase):
         exec_sql(self.eng, "SELECT ?, ?, ?", (lp1, lp2, lp3))
 
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] ('%s', '%s', '%s ... (372 characters truncated) "
             "... %s')" % (lp1, lp2, lp3[0:149], lp3[-149:]),
         )
@@ -261,7 +261,7 @@ class LogParamsTest(fixtures.TestBase):
         )
 
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] [('%s ... (4702 characters truncated) ... %s',), "
             "('%s',), "
             "('%s ... (372 characters truncated) ... %s',)]"
@@ -347,20 +347,20 @@ class LogParamsTest(fixtures.TestBase):
         row = result.first()
 
         eq_(
-            self.buf.buffer[1].message,
+            self.buf.buffer[2].message,
             "[raw sql] ('%s ... (4702 characters truncated) ... %s',)"
             % (largeparam[0:149], largeparam[-149:]),
         )
 
         if util.py3k:
             eq_(
-                self.buf.buffer[3].message,
+                self.buf.buffer[5].message,
                 "Row ('%s ... (4702 characters truncated) ... %s',)"
                 % (largeparam[0:149], largeparam[-149:]),
             )
         else:
             eq_(
-                self.buf.buffer[3].message,
+                self.buf.buffer[5].message,
                 "Row (u'%s ... (4703 characters truncated) ... %s',)"
                 % (largeparam[0:148], largeparam[-149:]),
             )
@@ -495,7 +495,8 @@ class LoggingNameTest(fixtures.TestBase):
     __requires__ = ("ad_hoc_engines",)
 
     def _assert_names_in_execute(self, eng, eng_name, pool_name):
-        eng.execute(select(1))
+        with eng.connect() as conn:
+            conn.execute(select(1))
         assert self.buf.buffer
         for name in [b.name for b in self.buf.buffer]:
             assert name in (
@@ -505,7 +506,8 @@ class LoggingNameTest(fixtures.TestBase):
             )
 
     def _assert_no_name_in_execute(self, eng):
-        eng.execute(select(1))
+        with eng.connect() as conn:
+            conn.execute(select(1))
         assert self.buf.buffer
         for name in [b.name for b in self.buf.buffer]:
             assert name in (
@@ -548,7 +550,8 @@ class LoggingNameTest(fixtures.TestBase):
 
     def test_named_logger_names_after_dispose(self):
         eng = self._named_engine()
-        eng.execute(select(1))
+        with eng.connect() as conn:
+            conn.execute(select(1))
         eng.dispose()
         eq_(eng.logging_name, "myenginename")
         eq_(eng.pool.logging_name, "mypoolname")
@@ -568,7 +571,8 @@ class LoggingNameTest(fixtures.TestBase):
 
     def test_named_logger_execute_after_dispose(self):
         eng = self._named_engine()
-        eng.execute(select(1))
+        with eng.connect() as conn:
+            conn.execute(select(1))
         eng.dispose()
         self._assert_names_in_execute(eng, "myenginename", "mypoolname")
 
@@ -599,7 +603,8 @@ class EchoTest(fixtures.TestBase):
 
         # do an initial execute to clear out 'first connect'
         # messages
-        e.execute(select(10)).close()
+        with e.connect() as conn:
+            conn.execute(select(10)).close()
         self.buf.flush()
 
         return e
@@ -637,16 +642,25 @@ class EchoTest(fixtures.TestBase):
         e2 = self._testing_engine()
 
         e1.echo = True
-        e1.execute(select(1)).close()
-        e2.execute(select(2)).close()
+
+        with e1.connect() as conn:
+            conn.execute(select(1)).close()
+
+        with e2.connect() as conn:
+            conn.execute(select(2)).close()
 
         e1.echo = False
-        e1.execute(select(3)).close()
-        e2.execute(select(4)).close()
+
+        with e1.connect() as conn:
+            conn.execute(select(3)).close()
+        with e2.connect() as conn:
+            conn.execute(select(4)).close()
 
         e2.echo = True
-        e1.execute(select(5)).close()
-        e2.execute(select(6)).close()
+        with e1.connect() as conn:
+            conn.execute(select(5)).close()
+        with e2.connect() as conn:
+            conn.execute(select(6)).close()
 
         assert self.buf.buffer[0].getMessage().startswith("SELECT 1")
         assert self.buf.buffer[2].getMessage().startswith("SELECT 6")
index 0dc35f99e8507a549a19ee2a2a2068d51dd711d6..ebdaa79a082651a215e0294331301911270c1630 100644 (file)
@@ -1340,20 +1340,24 @@ class InvalidateDuringResultTest(fixtures.TestBase):
 
     def setup(self):
         self.engine = engines.reconnecting_engine()
-        self.meta = MetaData(self.engine)
+        self.meta = MetaData()
         table = Table(
             "sometable",
             self.meta,
             Column("id", Integer, primary_key=True),
             Column("name", String(50)),
         )
-        self.meta.create_all()
-        table.insert().execute(
-            [{"id": i, "name": "row %d" % i} for i in range(1, 100)]
-        )
+
+        with self.engine.begin() as conn:
+            self.meta.create_all(conn)
+            conn.execute(
+                table.insert(),
+                [{"id": i, "name": "row %d" % i} for i in range(1, 100)],
+            )
 
     def teardown(self):
-        self.meta.drop_all()
+        with self.engine.begin() as conn:
+            self.meta.drop_all(conn)
         self.engine.dispose()
 
     @testing.crashes(
index b19836c84241572909e95d53127b686e376fc7ba..48b6c40d771b4923cfe6dacf0aad1fe2d3b4650a 100644 (file)
@@ -2016,7 +2016,7 @@ def createIndexes(con, schema=None):
 
 @testing.requires.views
 def _create_views(con, schema=None):
-    with testing.db.connect() as conn:
+    with testing.db.begin() as conn:
         for table_name in ("users", "email_addresses"):
             fullname = table_name
             if schema:
@@ -2031,7 +2031,7 @@ def _create_views(con, schema=None):
 
 @testing.requires.views
 def _drop_views(con, schema=None):
-    with testing.db.connect() as conn:
+    with testing.db.begin() as conn:
         for table_name in ("email_addresses", "users"):
             fullname = table_name
             if schema:
@@ -2047,7 +2047,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
 
     @testing.requires.denormalized_names
     def setup(self):
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql(
                 """
             CREATE TABLE weird_casing(
@@ -2060,7 +2060,7 @@ class ReverseCasingReflectTest(fixtures.TestBase, AssertsCompiledSQL):
 
     @testing.requires.denormalized_names
     def teardown(self):
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.exec_driver_sql("drop table weird_casing")
 
     @testing.requires.denormalized_names
index d0774e84641b706453b853a75dac01567a49eaea..4db5a745ada5339cf59d32fece67f664bd7a27b2 100644 (file)
@@ -5,20 +5,16 @@ from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import func
 from sqlalchemy import INT
-from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import pool as _pool
 from sqlalchemy import select
-from sqlalchemy import String
 from sqlalchemy import testing
-from sqlalchemy import text
 from sqlalchemy import util
 from sqlalchemy import VARCHAR
 from sqlalchemy.engine import base
 from sqlalchemy.engine import characteristics
 from sqlalchemy.engine import default
 from sqlalchemy.engine import url
-from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
@@ -29,31 +25,19 @@ from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
-users, metadata = None, None
 
-
-class TransactionTest(fixtures.TestBase):
+class TransactionTest(fixtures.TablesTest):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global users, metadata
-        metadata = MetaData()
-        users = Table(
-            "query_users",
+    def define_tables(cls, metadata):
+        Table(
+            "users",
             metadata,
             Column("user_id", INT, primary_key=True),
             Column("user_name", VARCHAR(20)),
             test_needs_acid=True,
         )
-        users.create(testing.db)
-
-    def teardown(self):
-        testing.db.execute(users.delete()).close()
-
-    @classmethod
-    def teardown_class(cls):
-        users.drop(testing.db)
 
     @testing.fixture
     def local_connection(self):
@@ -61,6 +45,7 @@ class TransactionTest(fixtures.TestBase):
             yield conn
 
     def test_commits(self, local_connection):
+        users = self.tables.users
         connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
@@ -72,7 +57,7 @@ class TransactionTest(fixtures.TestBase):
         transaction.commit()
 
         transaction = connection.begin()
-        result = connection.exec_driver_sql("select * from query_users")
+        result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 3
         transaction.commit()
         connection.close()
@@ -80,17 +65,19 @@ class TransactionTest(fixtures.TestBase):
     def test_rollback(self, local_connection):
         """test a basic rollback"""
 
+        users = self.tables.users
         connection = local_connection
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
         connection.execute(users.insert(), user_id=3, user_name="user3")
         transaction.rollback()
-        result = connection.exec_driver_sql("select * from query_users")
+        result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 0
 
     def test_raise(self, local_connection):
         connection = local_connection
+        users = self.tables.users
 
         transaction = connection.begin()
         try:
@@ -103,11 +90,12 @@ class TransactionTest(fixtures.TestBase):
             print("Exception: ", e)
             transaction.rollback()
 
-        result = connection.exec_driver_sql("select * from query_users")
+        result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 0
 
     def test_nested_rollback(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         try:
             transaction = connection.begin()
             try:
@@ -146,6 +134,7 @@ class TransactionTest(fixtures.TestBase):
 
     def test_branch_nested_rollback(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         connection.begin()
         branched = connection.connect()
         assert branched.in_transaction()
@@ -179,6 +168,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.savepoints
     def test_savepoint_cancelled_by_toplevel_marker(self, local_connection):
         conn = local_connection
+        users = self.tables.users
         trans = conn.begin()
         conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
 
@@ -245,85 +235,6 @@ class TransactionTest(fixtures.TestBase):
             nested.commit,
         )
 
-    def test_branch_autorollback(self, local_connection):
-        connection = local_connection
-        branched = connection.connect()
-        branched.execute(users.insert(), dict(user_id=1, user_name="user1"))
-        assert_raises(
-            exc.DBAPIError,
-            branched.execute,
-            users.insert(),
-            dict(user_id=1, user_name="user1"),
-        )
-        # can continue w/o issue
-        branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
-
-    def test_branch_orig_rollback(self, local_connection):
-        connection = local_connection
-        branched = connection.connect()
-        branched.execute(users.insert(), dict(user_id=1, user_name="user1"))
-        nested = branched.begin()
-        assert branched.in_transaction()
-        branched.execute(users.insert(), dict(user_id=2, user_name="user2"))
-        nested.rollback()
-        eq_(
-            connection.exec_driver_sql(
-                "select count(*) from query_users"
-            ).scalar(),
-            1,
-        )
-
-    @testing.requires.independent_connections
-    def test_branch_autocommit(self, local_connection):
-        with testing.db.connect() as connection:
-            branched = connection.connect()
-            branched.execute(
-                users.insert(), dict(user_id=1, user_name="user1")
-            )
-
-        eq_(
-            local_connection.execute(
-                text("select count(*) from query_users")
-            ).scalar(),
-            1,
-        )
-
-    @testing.requires.savepoints
-    def test_branch_savepoint_rollback(self, local_connection):
-        connection = local_connection
-        trans = connection.begin()
-        branched = connection.connect()
-        assert branched.in_transaction()
-        branched.execute(users.insert(), user_id=1, user_name="user1")
-        nested = branched.begin_nested()
-        branched.execute(users.insert(), user_id=2, user_name="user2")
-        nested.rollback()
-        assert connection.in_transaction()
-        trans.commit()
-        eq_(
-            connection.exec_driver_sql(
-                "select count(*) from query_users"
-            ).scalar(),
-            1,
-        )
-
-    @testing.requires.two_phase_transactions
-    def test_branch_twophase_rollback(self, local_connection):
-        connection = local_connection
-        branched = connection.connect()
-        assert not branched.in_transaction()
-        branched.execute(users.insert(), user_id=1, user_name="user1")
-        nested = branched.begin_twophase()
-        branched.execute(users.insert(), user_id=2, user_name="user2")
-        nested.rollback()
-        assert not connection.in_transaction()
-        eq_(
-            connection.exec_driver_sql(
-                "select count(*) from query_users"
-            ).scalar(),
-            1,
-        )
-
     def test_deactivated_warning_ctxmanager(self, local_connection):
         with expect_warnings(
             "transaction already deassociated from connection"
@@ -472,20 +383,20 @@ class TransactionTest(fixtures.TestBase):
 
     def test_retains_through_options(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         conn2 = connection.execution_options(dummy=True)
         conn2.execute(users.insert(), user_id=2, user_name="user2")
         transaction.rollback()
         eq_(
-            connection.exec_driver_sql(
-                "select count(*) from query_users"
-            ).scalar(),
+            connection.exec_driver_sql("select count(*) from users").scalar(),
             0,
         )
 
     def test_nesting(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -497,15 +408,16 @@ class TransactionTest(fixtures.TestBase):
         transaction.rollback()
         self.assert_(
             connection.exec_driver_sql(
-                "select count(*) from " "query_users"
+                "select count(*) from " "users"
             ).scalar()
             == 0
         )
-        result = connection.exec_driver_sql("select * from query_users")
+        result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 0
 
     def test_with_interface(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         trans = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -517,7 +429,7 @@ class TransactionTest(fixtures.TestBase):
         assert not trans.is_active
         self.assert_(
             connection.exec_driver_sql(
-                "select count(*) from " "query_users"
+                "select count(*) from " "users"
             ).scalar()
             == 0
         )
@@ -528,13 +440,14 @@ class TransactionTest(fixtures.TestBase):
         assert not trans.is_active
         self.assert_(
             connection.exec_driver_sql(
-                "select count(*) from " "query_users"
+                "select count(*) from " "users"
             ).scalar()
             == 1
         )
 
     def test_close(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -549,15 +462,16 @@ class TransactionTest(fixtures.TestBase):
         assert not connection.in_transaction()
         self.assert_(
             connection.exec_driver_sql(
-                "select count(*) from " "query_users"
+                "select count(*) from " "users"
             ).scalar()
             == 5
         )
-        result = connection.exec_driver_sql("select * from query_users")
+        result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 5
 
     def test_close2(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         connection.execute(users.insert(), user_id=2, user_name="user2")
@@ -572,16 +486,17 @@ class TransactionTest(fixtures.TestBase):
         assert not connection.in_transaction()
         self.assert_(
             connection.exec_driver_sql(
-                "select count(*) from " "query_users"
+                "select count(*) from " "users"
             ).scalar()
             == 0
         )
-        result = connection.exec_driver_sql("select * from query_users")
+        result = connection.exec_driver_sql("select * from users")
         assert len(result.fetchall()) == 0
 
     @testing.requires.savepoints
     def test_nested_subtransaction_rollback(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         trans2 = connection.begin_nested()
@@ -599,6 +514,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.savepoints
     def test_nested_subtransaction_commit(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         trans2 = connection.begin_nested()
@@ -616,6 +532,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.savepoints
     def test_rollback_to_subtransaction(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         trans2 = connection.begin_nested()
@@ -646,6 +563,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.two_phase_transactions
     def test_two_phase_transaction(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin_twophase()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         transaction.prepare()
@@ -680,6 +598,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.savepoints
     def test_mixed_two_phase_transaction(self, local_connection):
         connection = local_connection
+        users = self.tables.users
         transaction = connection.begin_twophase()
         connection.execute(users.insert(), user_id=1, user_name="user1")
         transaction2 = connection.begin()
@@ -704,6 +623,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.two_phase_transactions
     @testing.requires.two_phase_recovery
     def test_two_phase_recover(self):
+        users = self.tables.users
 
         # 2020, still can't get this to work w/ modern MySQL or MariaDB.
         # the XA RECOVER comes back as bytes, OK, convert to string,
@@ -722,11 +642,14 @@ class TransactionTest(fixtures.TestBase):
 
         with testing.db.connect() as connection2:
             eq_(
-                connection2.execution_options(autocommit=True)
-                .execute(select(users.c.user_id).order_by(users.c.user_id))
-                .fetchall(),
+                connection2.execute(
+                    select(users.c.user_id).order_by(users.c.user_id)
+                ).fetchall(),
                 [],
             )
+
+        # recover_twophase needs to be run in a new transaction
+        with testing.db.connect() as connection2:
             recoverables = connection2.recover_twophase()
             assert transaction.xid in recoverables
             connection2.commit_prepared(transaction.xid, recover=True)
@@ -740,6 +663,7 @@ class TransactionTest(fixtures.TestBase):
     @testing.requires.two_phase_transactions
     def test_multiple_two_phase(self, local_connection):
         conn = local_connection
+        users = self.tables.users
         xa = conn.begin_twophase()
         conn.execute(users.insert(), user_id=1, user_name="user1")
         xa.prepare()
@@ -767,6 +691,7 @@ class TransactionTest(fixtures.TestBase):
         # so that picky backends like MySQL correctly clear out
         # their state when a connection is closed without handling
         # the transaction explicitly.
+        users = self.tables.users
 
         eng = testing_engine()
 
@@ -1005,7 +930,8 @@ class AutoRollbackTest(fixtures.TestBase):
             Column("user_name", VARCHAR(20)),
             test_needs_acid=True,
         )
-        users.create(conn1)
+        with conn1.begin():
+            users.create(conn1)
         conn1.exec_driver_sql("select * from deadlock_users")
         conn1.close()
 
@@ -1014,125 +940,8 @@ class AutoRollbackTest(fixtures.TestBase):
         # pool but still has a lock on "deadlock_users". comment out the
         # rollback in pool/ConnectionFairy._close() to see !
 
-        users.drop(conn2)
-        conn2.close()
-
-
-class ExplicitAutoCommitTest(fixtures.TestBase):
-
-    """test the 'autocommit' flag on select() and text() objects.
-
-    Requires PostgreSQL so that we may define a custom function which
-    modifies the database."""
-
-    __only_on__ = "postgresql"
-
-    @classmethod
-    def setup_class(cls):
-        global metadata, foo
-        metadata = MetaData(testing.db)
-        foo = Table(
-            "foo",
-            metadata,
-            Column("id", Integer, primary_key=True),
-            Column("data", String(100)),
-        )
-        with testing.db.connect() as conn:
-            metadata.create_all(conn)
-            conn.exec_driver_sql(
-                "create function insert_foo(varchar) "
-                "returns integer as 'insert into foo(data) "
-                "values ($1);select 1;' language sql"
-            )
-
-    def teardown(self):
-        with testing.db.connect() as conn:
-            conn.execute(foo.delete())
-
-    @classmethod
-    def teardown_class(cls):
-        with testing.db.connect() as conn:
-            conn.exec_driver_sql("drop function insert_foo(varchar)")
-            metadata.drop_all(conn)
-
-    def test_control(self):
-
-        # test that not using autocommit does not commit
-
-        conn1 = testing.db.connect()
-        conn2 = testing.db.connect()
-        conn1.execute(select(func.insert_foo("data1")))
-        assert conn2.execute(select(foo.c.data)).fetchall() == []
-        conn1.execute(text("select insert_foo('moredata')"))
-        assert conn2.execute(select(foo.c.data)).fetchall() == []
-        trans = conn1.begin()
-        trans.commit()
-        assert conn2.execute(select(foo.c.data)).fetchall() == [
-            ("data1",),
-            ("moredata",),
-        ]
-        conn1.close()
-        conn2.close()
-
-    def test_explicit_compiled(self):
-        conn1 = testing.db.connect()
-        conn2 = testing.db.connect()
-        conn1.execute(
-            select(func.insert_foo("data1")).execution_options(autocommit=True)
-        )
-        assert conn2.execute(select(foo.c.data)).fetchall() == [("data1",)]
-        conn1.close()
-        conn2.close()
-
-    def test_explicit_connection(self):
-        conn1 = testing.db.connect()
-        conn2 = testing.db.connect()
-        conn1.execution_options(autocommit=True).execute(
-            select(func.insert_foo("data1"))
-        )
-        eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
-
-        # connection supersedes statement
-
-        conn1.execution_options(autocommit=False).execute(
-            select(func.insert_foo("data2")).execution_options(autocommit=True)
-        )
-        eq_(conn2.execute(select(foo.c.data)).fetchall(), [("data1",)])
-
-        # ditto
-
-        conn1.execution_options(autocommit=True).execute(
-            select(func.insert_foo("data3")).execution_options(
-                autocommit=False
-            )
-        )
-        eq_(
-            conn2.execute(select(foo.c.data)).fetchall(),
-            [("data1",), ("data2",), ("data3",)],
-        )
-        conn1.close()
-        conn2.close()
-
-    def test_explicit_text(self):
-        conn1 = testing.db.connect()
-        conn2 = testing.db.connect()
-        conn1.execute(
-            text("select insert_foo('moredata')").execution_options(
-                autocommit=True
-            )
-        )
-        assert conn2.execute(select(foo.c.data)).fetchall() == [("moredata",)]
-        conn1.close()
-        conn2.close()
-
-    def test_implicit_text(self):
-        conn1 = testing.db.connect()
-        conn2 = testing.db.connect()
-        conn1.execute(text("insert into foo (data) values ('implicitdata')"))
-        assert conn2.execute(select(foo.c.data)).fetchall() == [
-            ("implicitdata",)
-        ]
-        conn1.close()
+        with conn2.begin():
+            users.drop(conn2)
         conn2.close()
 
 
index 3cb29c67dc5868250a7bbbf373a48e9341fea278..df27c8d270f6593edfdeb7b20b06d7ed6d287e27 100644 (file)
@@ -1329,10 +1329,13 @@ class KVChild(object):
         self.value = value
 
 
-class ReconstitutionTest(fixtures.TestBase):
-    def setup(self):
-        metadata = MetaData(testing.db)
-        parents = Table(
+class ReconstitutionTest(fixtures.MappedTest):
+    run_setup_mappers = "each"
+    run_setup_classes = "each"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
             "parents",
             metadata,
             Column(
@@ -1340,7 +1343,7 @@ class ReconstitutionTest(fixtures.TestBase):
             ),
             Column("name", String(30)),
         )
-        children = Table(
+        Table(
             "children",
             metadata,
             Column(
@@ -1349,22 +1352,23 @@ class ReconstitutionTest(fixtures.TestBase):
             Column("parent_id", Integer, ForeignKey("parents.id")),
             Column("name", String(30)),
         )
-        metadata.create_all()
-        parents.insert().execute(name="p1")
-        self.metadata = metadata
-        self.parents = parents
-        self.children = children
-        Parent.kids = association_proxy("children", "name")
 
-    def teardown(self):
-        self.metadata.drop_all()
-        clear_mappers()
+    @classmethod
+    def insert_data(cls, connection):
+        parents = cls.tables.parents
+        connection.execute(parents.insert(), dict(name="p1"))
+
+    @classmethod
+    def setup_classes(cls):
+        Parent.kids = association_proxy("children", "name")
 
     def test_weak_identity_map(self):
         mapper(
-            Parent, self.parents, properties=dict(children=relationship(Child))
+            Parent,
+            self.tables.parents,
+            properties=dict(children=relationship(Child)),
         )
-        mapper(Child, self.children)
+        mapper(Child, self.tables.children)
         session = create_session()
 
         def add_child(parent_name, child_name):
@@ -1380,9 +1384,11 @@ class ReconstitutionTest(fixtures.TestBase):
 
     def test_copy(self):
         mapper(
-            Parent, self.parents, properties=dict(children=relationship(Child))
+            Parent,
+            self.tables.parents,
+            properties=dict(children=relationship(Child)),
         )
-        mapper(Child, self.children)
+        mapper(Child, self.tables.children)
         p = Parent("p1")
         p.kids.extend(["c1", "c2"])
         p_copy = copy.copy(p)
@@ -1392,9 +1398,11 @@ class ReconstitutionTest(fixtures.TestBase):
 
     def test_pickle_list(self):
         mapper(
-            Parent, self.parents, properties=dict(children=relationship(Child))
+            Parent,
+            self.tables.parents,
+            properties=dict(children=relationship(Child)),
         )
-        mapper(Child, self.children)
+        mapper(Child, self.tables.children)
         p = Parent("p1")
         p.kids.extend(["c1", "c2"])
         r1 = pickle.loads(pickle.dumps(p))
@@ -1407,12 +1415,12 @@ class ReconstitutionTest(fixtures.TestBase):
     def test_pickle_set(self):
         mapper(
             Parent,
-            self.parents,
+            self.tables.parents,
             properties=dict(
                 children=relationship(Child, collection_class=set)
             ),
         )
-        mapper(Child, self.children)
+        mapper(Child, self.tables.children)
         p = Parent("p1")
         p.kids.update(["c1", "c2"])
         r1 = pickle.loads(pickle.dumps(p))
@@ -1425,7 +1433,7 @@ class ReconstitutionTest(fixtures.TestBase):
     def test_pickle_dict(self):
         mapper(
             Parent,
-            self.parents,
+            self.tables.parents,
             properties=dict(
                 children=relationship(
                     KVChild,
@@ -1435,7 +1443,7 @@ class ReconstitutionTest(fixtures.TestBase):
                 )
             ),
         )
-        mapper(KVChild, self.children)
+        mapper(KVChild, self.tables.children)
         p = Parent("p1")
         p.kids.update({"c1": "v1", "c2": "v2"})
         assert p.kids == {"c1": "c1", "c2": "c2"}
index a8c17d7aca2548f8ccb74df4cd5328770c7471fb..e46c65ff02f4eed5047678d5e26409d2e8b8acae 100644 (file)
@@ -53,10 +53,10 @@ class ShardTest(object):
         def id_generator(ctx):
             # in reality, might want to use a separate transaction for this.
 
-            c = db1.connect()
-            nextid = c.execute(ids.select().with_for_update()).scalar()
-            c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
-            return nextid
+            with db1.begin() as c:
+                nextid = c.execute(ids.select().with_for_update()).scalar()
+                c.execute(ids.update(values={ids.c.nextid: ids.c.nextid + 1}))
+                return nextid
 
         weather_locations = Table(
             "weather_locations",
@@ -80,7 +80,8 @@ class ShardTest(object):
         for db in (db1, db2, db3, db4):
             meta.create_all(db)
 
-        db1.execute(ids.insert(), nextid=1)
+        with db1.begin() as conn:
+            conn.execute(ids.insert(), dict(nextid=1))
 
         self.setup_session()
         self.setup_mappers()
@@ -762,7 +763,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
                 )
 
         e2 = testing_engine()
-        with e2.connect() as conn:
+        with e2.begin() as conn:
             for i in [2, 4]:
                 conn.exec_driver_sql(
                     "CREATE SCHEMA IF NOT EXISTS shard%s" % (i,)
@@ -784,7 +785,7 @@ class MultipleDialectShardTest(ShardTest, fixtures.TestBase):
         for i in [1, 3]:
             os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
 
-        with self.postgresql_engine.connect() as conn:
+        with self.postgresql_engine.begin() as conn:
             self.metadata.drop_all(conn)
             for i in [2, 4]:
                 conn.exec_driver_sql("DROP SCHEMA shard%s CASCADE" % (i,))
index c9a78db081db0c672d47145dfd75577072bc7a73..dab1841943f5b2df32dbbc9d1120631de65cd422 100644 (file)
@@ -2,7 +2,6 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import String
-from sqlalchemy import testing
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import Session
 from sqlalchemy.testing import eq_
@@ -24,13 +23,13 @@ class InheritingSelectablesTest(fixtures.MappedTest):
         cls.tables.bar = foo.select(foo.c.b == "bar").alias("bar")
         cls.tables.baz = foo.select(foo.c.b == "baz").alias("baz")
 
-    def test_load(self):
+    def test_load(self, connection):
         foo, bar, baz = self.tables.foo, self.tables.bar, self.tables.baz
         # TODO: add persistence test also
-        testing.db.execute(foo.insert(), a="not bar", b="baz")
-        testing.db.execute(foo.insert(), a="also not bar", b="baz")
-        testing.db.execute(foo.insert(), a="i am bar", b="bar")
-        testing.db.execute(foo.insert(), a="also bar", b="bar")
+        connection.execute(foo.insert(), dict(a="not bar", b="baz"))
+        connection.execute(foo.insert(), dict(a="also not bar", b="baz"))
+        connection.execute(foo.insert(), dict(a="i am bar", b="bar"))
+        connection.execute(foo.insert(), dict(a="also bar", b="bar"))
 
         class Foo(fixtures.ComparableEntity):
             pass
@@ -69,8 +68,8 @@ class InheritingSelectablesTest(fixtures.MappedTest):
             polymorphic_identity="bar",
         )
 
-        s = Session()
-        assert [Bar(), Bar()] == s.query(Bar).all()
+        s = Session(connection)
+        eq_(s.query(Bar).all(), [Bar(), Bar()])
 
 
 class JoinFromSelectPersistenceTest(fixtures.MappedTest):
index 3a9959857075b2545a1044ac8bbf054c855b8a1b..64f85b3351e72ec3b3b76a1993559562013abfa9 100644 (file)
@@ -151,7 +151,7 @@ class BindIntegrationTest(_fixtures.FixtureTest):
 
         mapper(User, users)
 
-        session = create_session()
+        session = Session()
 
         session.execute(users.insert(), dict(name="Johnny"))
 
@@ -447,7 +447,9 @@ class BindIntegrationTest(_fixtures.FixtureTest):
         sess.commit()
         assert not c.in_transaction()
         assert c.exec_driver_sql("select count(1) from users").scalar() == 1
-        c.exec_driver_sql("delete from users")
+
+        with c.begin():
+            c.exec_driver_sql("delete from users")
         assert c.exec_driver_sql("select count(1) from users").scalar() == 0
 
         c = testing.db.connect()
index dcf07eec8a3258cc4869499953117941a36b5384..c6a1226d4bb4f7aa00d41f0e46d0b2dbc9fa4589 100644 (file)
@@ -190,8 +190,9 @@ class CompileTest(fixtures.ORMTest):
             sa_exc.ArgumentError, "Error creating backref", configure_mappers
         )
 
-    def test_misc_one(self):
-        metadata = MetaData(testing.db)
+    @testing.provide_metadata
+    def test_misc_one(self, connection):
+        metadata = self.metadata
         node_table = Table(
             "node",
             metadata,
@@ -212,33 +213,30 @@ class CompileTest(fixtures.ORMTest):
             Column("host_id", Integer, primary_key=True),
             Column("hostname", String(64), nullable=False, unique=True),
         )
-        metadata.create_all()
-        try:
-            node_table.insert().execute(node_id=1, node_index=5)
-
-            class Node(object):
-                pass
-
-            class NodeName(object):
-                pass
-
-            class Host(object):
-                pass
-
-            mapper(Node, node_table)
-            mapper(Host, host_table)
-            mapper(
-                NodeName,
-                node_name_table,
-                properties={
-                    "node": relationship(Node, backref=backref("names")),
-                    "host": relationship(Host),
-                },
-            )
-            sess = create_session()
-            assert sess.query(Node).get(1).names == []
-        finally:
-            metadata.drop_all()
+        metadata.create_all(connection)
+        connection.execute(node_table.insert(), dict(node_id=1, node_index=5))
+
+        class Node(object):
+            pass
+
+        class NodeName(object):
+            pass
+
+        class Host(object):
+            pass
+
+        mapper(Node, node_table)
+        mapper(Host, host_table)
+        mapper(
+            NodeName,
+            node_name_table,
+            properties={
+                "node": relationship(Node, backref=backref("names")),
+                "host": relationship(Host),
+            },
+        )
+        sess = create_session(connection)
+        assert sess.query(Node).get(1).names == []
 
     def test_conflicting_backref_two(self):
         meta = MetaData()
index 57225d640680981eefc9b2b402574be10403453c..7bc82b2a3ac37e53c24d284d7813196d253cb3be 100644 (file)
@@ -4808,6 +4808,8 @@ class MixedEntitiesTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
 
 
 class SubqueryTest(fixtures.MappedTest):
+    run_deletes = "each"
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -4830,7 +4832,12 @@ class SubqueryTest(fixtures.MappedTest):
             Column("score2", sa.Float),
         )
 
-    def test_label_anonymizing(self):
+    @testing.combinations(
+        (True, "score"),
+        (True, None),
+        (False, None),
+    )
+    def test_label_anonymizing(self, labeled, labelname):
         """Eager loading works with subqueries with labels,
 
         Even if an explicit labelname which conflicts with a label on the
@@ -4859,75 +4866,65 @@ class SubqueryTest(fixtures.MappedTest):
             def prop_score(self):
                 return self.score1 * self.score2
 
-        for labeled, labelname in [
-            (True, "score"),
-            (True, None),
-            (False, None),
-        ]:
-            sa.orm.clear_mappers()
-
-            tag_score = tags_table.c.score1 * tags_table.c.score2
-            user_score = sa.select(
-                sa.func.sum(tags_table.c.score1 * tags_table.c.score2)
-            ).where(
-                tags_table.c.user_id == users_table.c.id,
-            )
+        tag_score = tags_table.c.score1 * tags_table.c.score2
+        user_score = sa.select(
+            sa.func.sum(tags_table.c.score1 * tags_table.c.score2)
+        ).where(
+            tags_table.c.user_id == users_table.c.id,
+        )
 
-            if labeled:
-                tag_score = tag_score.label(labelname)
-                user_score = user_score.label(labelname)
-            else:
-                user_score = user_score.scalar_subquery()
+        if labeled:
+            tag_score = tag_score.label(labelname)
+            user_score = user_score.label(labelname)
+        else:
+            user_score = user_score.scalar_subquery()
 
-            mapper(
-                Tag,
-                tags_table,
-                properties={"query_score": sa.orm.column_property(tag_score)},
-            )
+        mapper(
+            Tag,
+            tags_table,
+            properties={"query_score": sa.orm.column_property(tag_score)},
+        )
 
-            mapper(
-                User,
-                users_table,
-                properties={
-                    "tags": relationship(Tag, backref="user", lazy="joined"),
-                    "query_score": sa.orm.column_property(user_score),
-                },
-            )
+        mapper(
+            User,
+            users_table,
+            properties={
+                "tags": relationship(Tag, backref="user", lazy="joined"),
+                "query_score": sa.orm.column_property(user_score),
+            },
+        )
 
-            session = create_session()
-            session.add(
-                User(
-                    name="joe",
-                    tags=[
-                        Tag(score1=5.0, score2=3.0),
-                        Tag(score1=55.0, score2=1.0),
-                    ],
-                )
+        session = create_session()
+        session.add(
+            User(
+                name="joe",
+                tags=[
+                    Tag(score1=5.0, score2=3.0),
+                    Tag(score1=55.0, score2=1.0),
+                ],
             )
-            session.add(
-                User(
-                    name="bar",
-                    tags=[
-                        Tag(score1=5.0, score2=4.0),
-                        Tag(score1=50.0, score2=1.0),
-                        Tag(score1=15.0, score2=2.0),
-                    ],
-                )
+        )
+        session.add(
+            User(
+                name="bar",
+                tags=[
+                    Tag(score1=5.0, score2=4.0),
+                    Tag(score1=50.0, score2=1.0),
+                    Tag(score1=15.0, score2=2.0),
+                ],
             )
-            session.flush()
-            session.expunge_all()
-
-            for user in session.query(User).all():
-                eq_(user.query_score, user.prop_score)
+        )
+        session.flush()
+        session.expunge_all()
 
-            def go():
-                u = session.query(User).filter_by(name="joe").one()
-                eq_(u.query_score, u.prop_score)
+        for user in session.query(User).all():
+            eq_(user.query_score, user.prop_score)
 
-            self.assert_sql_count(testing.db, go, 1)
+        def go():
+            u = session.query(User).filter_by(name="joe").one()
+            eq_(u.query_score, u.prop_score)
 
-            for t in (tags_table, users_table):
-                t.delete().execute()
+        self.assert_sql_count(testing.db, go, 1)
 
 
 class CorrelatedSubqueryTest(fixtures.MappedTest):
index 7ccf2c1aee669af27ab2bc2e692ba4285ec93972..5abaa03db51d25f1c7531314ce7b6f109a108cd5 100644 (file)
@@ -9,7 +9,6 @@ from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import attributes
-from sqlalchemy.orm import create_session
 from sqlalchemy.orm import defer
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import exc as orm_exc
@@ -26,6 +25,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import create_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import gc_collect
@@ -66,7 +66,7 @@ class ExpireTest(_fixtures.FixtureTest):
         u.name = "foo"
         sess.flush()
         # change the value in the DB
-        users.update(users.c.id == 7, values=dict(name="jack")).execute()
+        sess.execute(users.update(users.c.id == 7, values=dict(name="jack")))
         sess.expire(u)
         # object isn't refreshed yet, using dict to bypass trigger
         assert u.__dict__.get("name") != "jack"
@@ -471,7 +471,7 @@ class ExpireTest(_fixtures.FixtureTest):
         o = sess.query(Order).get(3)
         sess.expire(o)
 
-        orders.update().execute(description="order 3 modified")
+        sess.execute(orders.update(), dict(description="order 3 modified"))
         assert o.isopen == 1
         assert (
             attributes.instance_state(o).dict["description"]
@@ -788,7 +788,7 @@ class ExpireTest(_fixtures.FixtureTest):
         sess.expire(u)
         assert "name" not in u.__dict__
 
-        users.update(users.c.id == 7).execute(name="jack2")
+        sess.execute(users.update(users.c.id == 7), dict(name="jack2"))
         assert u.name == "jack2"
         assert u.uname == "jack2"
         assert "name" in u.__dict__
@@ -812,7 +812,10 @@ class ExpireTest(_fixtures.FixtureTest):
         assert "description" not in o.__dict__
         assert attributes.instance_state(o).dict["isopen"] == 1
 
-        orders.update(orders.c.id == 3).execute(description="order 3 modified")
+        sess.execute(
+            orders.update(orders.c.id == 3),
+            dict(description="order 3 modified"),
+        )
 
         def go():
             assert o.description == "order 3 modified"
@@ -1660,12 +1663,9 @@ class LifecycleTest(fixtures.MappedTest):
     def test_cols_missing_in_load(self):
         Data = self.classes.Data
 
-        sess = create_session()
-
-        d1 = Data(data="d1")
-        sess.add(d1)
-        sess.flush()
-        sess.close()
+        with Session(testing.db) as sess, sess.begin():
+            d1 = Data(data="d1")
+            sess.add(d1)
 
         sess = create_session()
         d1 = sess.query(Data).from_statement(select(Data.id)).first()
@@ -1679,21 +1679,18 @@ class LifecycleTest(fixtures.MappedTest):
     def test_deferred_cols_missing_in_load_state_reset(self):
         Data = self.classes.DataDefer
 
-        sess = create_session()
+        with Session(testing.db) as sess, sess.begin():
+            d1 = Data(data="d1")
+            sess.add(d1)
 
-        d1 = Data(data="d1")
-        sess.add(d1)
-        sess.flush()
-        sess.close()
-
-        sess = create_session()
-        d1 = (
-            sess.query(Data)
-            .from_statement(select(Data.id))
-            .options(undefer(Data.data))
-            .first()
-        )
-        d1.data = "d2"
+        with Session(testing.db) as sess:
+            d1 = (
+                sess.query(Data)
+                .from_statement(select(Data.id))
+                .options(undefer(Data.data))
+                .first()
+            )
+            d1.data = "d2"
 
         # the deferred loader has to clear out any state
         # on the col, including that 'd2' here
index c1cc85261f1f5109ad8bc13597cbb652c835c645..e1c0ec77b8327370c7326489d9451199dc704b43 100644 (file)
@@ -1302,18 +1302,22 @@ class O2MWOSideFixedTest(fixtures.MappedTest):
     def _fixture(self, include_other):
         city, person = self.tables.city, self.tables.person
 
-        if include_other:
-            city.insert().execute({"id": 1, "deleted": False})
-
-            person.insert().execute(
-                {"id": 1, "city_id": 1}, {"id": 2, "city_id": 1}
-            )
+        with testing.db.begin() as conn:
+            if include_other:
+                conn.execute(city.insert(), {"id": 1, "deleted": False})
+
+                conn.execute(
+                    person.insert(),
+                    {"id": 1, "city_id": 1},
+                    {"id": 2, "city_id": 1},
+                )
 
-        city.insert().execute({"id": 2, "deleted": True})
+            conn.execute(city.insert(), {"id": 2, "deleted": True})
 
-        person.insert().execute(
-            {"id": 3, "city_id": 2}, {"id": 4, "city_id": 2}
-        )
+            conn.execute(
+                person.insert(),
+                [{"id": 3, "city_id": 2}, {"id": 4, "city_id": 2}],
+            )
 
     def test_lazyload_assert_expected_sql(self):
         self._fixture(True)
index fc6caa75d4155519520246b235c88e74f2081241..edbb4b0cd0ea7fe991fb3b495bd527f0344a5df0 100644 (file)
@@ -129,7 +129,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         )
         assert_raises(sa.exc.ArgumentError, sa.orm.configure_mappers)
 
-    def test_update_attr_keys(self):
+    def test_update_attr_keys(self, connection):
         """test that update()/insert() use the correct key when given
         InstrumentedAttributes."""
 
@@ -137,21 +137,21 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
 
         self.mapper(User, users, properties={"foobar": users.c.name})
 
-        users.insert().values({User.foobar: "name1"}).execute()
+        connection.execute(users.insert().values({User.foobar: "name1"}))
         eq_(
-            sa.select(User.foobar)
-            .where(User.foobar == "name1")
-            .execute()
-            .fetchall(),
+            connection.execute(
+                sa.select(User.foobar).where(User.foobar == "name1")
+            ).fetchall(),
             [("name1",)],
         )
 
-        users.update().values({User.foobar: User.foobar + "foo"}).execute()
+        connection.execute(
+            users.update().values({User.foobar: User.foobar + "foo"})
+        )
         eq_(
-            sa.select(User.foobar)
-            .where(User.foobar == "name1foo")
-            .execute()
-            .fetchall(),
+            connection.execute(
+                sa.select(User.foobar).where(User.foobar == "name1foo")
+            ).fetchall(),
             [("name1foo",)],
         )
 
index 87ec0d79d3e9daad5e4b5c464897d08f6e1ba62e..d814b0cab86c9821317b39a0856fc98bb7c6117e 100644 (file)
@@ -12,7 +12,6 @@ from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import TypeDecorator
-from sqlalchemy.orm import create_session
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
@@ -23,6 +22,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import create_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from test.orm import _fixtures
@@ -141,7 +141,9 @@ class NaturalPKTest(fixtures.MappedTest):
         sess.flush()
         assert sess.query(User).get("jack") is u1
 
-        users.update(values={User.username: "jack"}).execute(username="ed")
+        sess.execute(
+            users.update(values={User.username: "jack"}), dict(username="ed")
+        )
 
         # expire/refresh works off of primary key.  the PK is gone
         # in this case so there's no way to look it up.  criterion-
@@ -1089,7 +1091,7 @@ class NonPKCascadeTest(fixtures.MappedTest):
         a1 = u1.addresses[0]
 
         eq_(
-            sa.select(addresses.c.username).execute().fetchall(),
+            sess.execute(sa.select(addresses.c.username)).fetchall(),
             [("jack",), ("jack",)],
         )
 
@@ -1099,7 +1101,7 @@ class NonPKCascadeTest(fixtures.MappedTest):
         sess.flush()
         assert u1.addresses[0].username == "ed"
         eq_(
-            sa.select(addresses.c.username).execute().fetchall(),
+            sess.execute(sa.select(addresses.c.username)).fetchall(),
             [("ed",), ("ed",)],
         )
 
@@ -1141,7 +1143,7 @@ class NonPKCascadeTest(fixtures.MappedTest):
         eq_(a1.username, None)
 
         eq_(
-            sa.select(addresses.c.username).execute().fetchall(),
+            sess.execute(sa.select(addresses.c.username)).fetchall(),
             [(None,), (None,)],
         )
 
@@ -1454,7 +1456,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         eq_(a1.username, "ed")
         eq_(a2.username, "ed")
         eq_(
-            sa.select(addresses.c.username).execute().fetchall(),
+            sess.execute(sa.select(addresses.c.username)).fetchall(),
             [("ed",), ("ed",)],
         )
 
@@ -1465,7 +1467,7 @@ class CascadeToFKPKTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         eq_(a1.username, "jack")
         eq_(a2.username, "jack")
         eq_(
-            sa.select(addresses.c.username).execute().fetchall(),
+            sess.execute(sa.select(addresses.c.username)).fetchall(),
             [("jack",), ("jack",)],
         )
 
index 8cca45b2708fe650ab43ad3c21c0e1fd9aac0f23..9e528dc0d42035a1450c1e1f149914559ff8de68 100644 (file)
@@ -806,7 +806,7 @@ class GetTest(QueryTest):
 
     @testing.provide_metadata
     @testing.requires.unicode_connections
-    def test_unicode(self):
+    def test_unicode(self, connection):
         """test that Query.get properly sets up the type for the bind
         parameter. using unicode would normally fail on postgresql, mysql and
         oracle unless it is converted to an encoded string"""
@@ -818,19 +818,20 @@ class GetTest(QueryTest):
             Column("id", Unicode(40), primary_key=True),
             Column("data", Unicode(40)),
         )
-        metadata.create_all()
+        metadata.create_all(connection)
         ustring = util.b("petit voix m\xe2\x80\x99a").decode("utf-8")
 
-        table.insert().execute(id=ustring, data=ustring)
+        connection.execute(table.insert(), dict(id=ustring, data=ustring))
 
         class LocalFoo(self.classes.Base):
             pass
 
         mapper(LocalFoo, table)
-        eq_(
-            create_session().query(LocalFoo).get(ustring),
-            LocalFoo(id=ustring, data=ustring),
-        )
+        with Session(connection) as sess:
+            eq_(
+                sess.get(LocalFoo, ustring),
+                LocalFoo(id=ustring, data=ustring),
+            )
 
     def test_populate_existing(self):
         User, Address = self.classes.User, self.classes.Address
index 165008234603ec047e352c4ed0ffe172ce8834db..d2838e5bf854cf866c931224259ae08c127c7d2e 100644 (file)
@@ -12,7 +12,6 @@ from sqlalchemy import testing
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import close_all_sessions
-from sqlalchemy.orm import create_session
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import make_transient
@@ -35,6 +34,7 @@ from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import pickleable
+from sqlalchemy.testing.fixtures import create_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import gc_collect
@@ -48,33 +48,33 @@ class ExecutionTest(_fixtures.FixtureTest):
     __backend__ = True
 
     @testing.requires.sequences
-    def test_sequence_execute(self):
+    def test_sequence_execute(self, connection):
         seq = Sequence("some_sequence")
-        seq.create(testing.db)
+        seq.create(connection)
         try:
-            sess = create_session(bind=testing.db)
-            eq_(sess.execute(seq), testing.db.dialect.default_sequence_base)
+            sess = Session(connection)
+            eq_(sess.execute(seq), connection.dialect.default_sequence_base)
         finally:
-            seq.drop(testing.db)
+            seq.drop(connection)
 
-    def test_textual_execute(self):
+    def test_textual_execute(self, connection):
         """test that Session.execute() converts to text()"""
 
         users = self.tables.users
 
-        sess = create_session(bind=self.metadata.bind)
-        users.insert().execute(id=7, name="jack")
+        with Session(bind=connection) as sess:
+            sess.execute(users.insert(), dict(id=7, name="jack"))
 
-        # use :bindparam style
-        eq_(
-            sess.execute(
-                "select * from users where id=:id", {"id": 7}
-            ).fetchall(),
-            [(7, "jack")],
-        )
+            # use :bindparam style
+            eq_(
+                sess.execute(
+                    "select * from users where id=:id", {"id": 7}
+                ).fetchall(),
+                [(7, "jack")],
+            )
 
-        # use :bindparam style
-        eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7)
+            # use :bindparam style
+            eq_(sess.scalar("select id from users where id=:id", {"id": 7}), 7)
 
     def test_parameter_execute(self):
         users = self.tables.users
@@ -104,7 +104,7 @@ class TransScopingTest(_fixtures.FixtureTest):
         c.exec_driver_sql("select * from users")
 
         mapper(User, users)
-        s = create_session(bind=c)
+        s = Session(bind=c)
         s.add(User(name="first"))
         s.flush()
         c.exec_driver_sql("select * from users")
@@ -118,7 +118,7 @@ class TransScopingTest(_fixtures.FixtureTest):
         c.exec_driver_sql("select * from users")
 
         mapper(User, users)
-        s = create_session(bind=c)
+        s = Session(bind=c)
         s.add(User(name="first"))
         s.flush()
         c.exec_driver_sql("select * from users")
@@ -189,7 +189,7 @@ class TransScopingTest(_fixtures.FixtureTest):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(autocommit=False, bind=conn1)
+        sess = Session(autocommit=False, bind=conn1)
         u = User(name="x")
         sess.add(u)
         sess.flush()
@@ -415,7 +415,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         conn1 = bind.connect()
         conn2 = bind.connect()
 
-        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
+        sess = Session(bind=conn1, autocommit=False, autoflush=True)
         u = User()
         u.name = "ed"
         sess.add(u)
@@ -600,7 +600,7 @@ class SessionStateTest(_fixtures.FixtureTest):
 
         mapper(User, users)
         conn1 = testing.db.connect()
-        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
+        sess = Session(bind=conn1, autocommit=False, autoflush=True)
         u = User()
         u.name = "ed"
         sess.add(u)
@@ -620,7 +620,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         User, users = self.classes.User, self.tables.users
 
         mapper(User, users)
-        session = create_session(autocommit=True)
+        session = Session(testing.db, autocommit=True)
 
         session.add(User(name="ed"))
 
@@ -629,7 +629,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         session.commit()
 
     def test_active_flag_autocommit(self):
-        sess = create_session(bind=config.db, autocommit=True)
+        sess = Session(bind=config.db, autocommit=True)
         assert not sess.is_active
         sess.begin()
         assert sess.is_active
@@ -637,7 +637,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         assert not sess.is_active
 
     def test_active_flag_autobegin(self):
-        sess = create_session(bind=config.db, autocommit=False)
+        sess = Session(bind=config.db, autocommit=False)
         assert sess.is_active
         assert not sess.in_transaction()
         sess.begin()
@@ -646,7 +646,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         assert sess.is_active
 
     def test_active_flag_autobegin_future(self):
-        sess = create_session(bind=config.db, future=True)
+        sess = Session(bind=config.db, future=True)
         assert sess.is_active
         assert not sess.in_transaction()
         sess.begin()
@@ -655,7 +655,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         assert sess.is_active
 
     def test_active_flag_partial_rollback(self):
-        sess = create_session(bind=config.db, autocommit=False)
+        sess = Session(bind=config.db, autocommit=False)
         assert sess.is_active
         assert not sess.in_transaction()
         sess.begin()
@@ -693,7 +693,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         )
 
         s.add(user)
-        s.flush()
+        s.commit()
         user = s.query(User).one()
         s.expunge(user)
         assert user not in s
@@ -703,8 +703,7 @@ class SessionStateTest(_fixtures.FixtureTest):
         s.add(user)
         assert user in s
         assert user in s.dirty
-        s.flush()
-        s.expunge_all()
+        s.commit()
         assert s.query(User).count() == 1
         user = s.query(User).one()
         assert user.name == "fred"
@@ -766,8 +765,9 @@ class SessionStateTest(_fixtures.FixtureTest):
         users, User = self.tables.users, self.classes.User
 
         mapper(User, users)
-        for s in (create_session(), create_session()):
-            users.delete().execute()
+
+        with create_session() as s:
+            s.execute(users.delete())
             u1 = User(name="ed")
             s.add(u1)
             s.flush()
@@ -1774,7 +1774,8 @@ class DisposedStates(fixtures.MappedTest):
 
     def _test_session(self, **kwargs):
         T = self.classes.T
-        sess = create_session(**kwargs)
+
+        sess = Session(config.db, **kwargs)
 
         data = o1, o2, o3, o4, o5 = [
             T("t1"),
@@ -1786,7 +1787,7 @@ class DisposedStates(fixtures.MappedTest):
 
         sess.add_all(data)
 
-        sess.flush()
+        sess.commit()
 
         o1.data = "t1modified"
         o5.data = "t5modified"
@@ -1925,7 +1926,7 @@ class SessionInterface(fixtures.TestBase):
 
         def raises_(method, *args, **kw):
             watchdog.add(method)
-            callable_ = getattr(create_session(), method)
+            callable_ = getattr(Session(), method)
             if is_class:
                 assert_raises(
                     sa.orm.exc.UnmappedClassError, callable_, *args, **kw
index e8f6c5c4052daf451abcc82b370c651d59e0aee8..248f334cf6a520e7af53e9f15258b52ff496a2dd 100644 (file)
@@ -1951,9 +1951,7 @@ class AccountingFlagsTest(_LocalFixture):
         sess.add(u1)
         sess.commit()
 
-        testing.db.execute(
-            users.update(users.c.name == "ed").values(name="edward")
-        )
+        sess.execute(users.update(users.c.name == "ed").values(name="edward"))
 
         assert u1.name == "ed"
         sess.expire_all()
index ed320db10426b89baed301b11a575981616633c6..31386b07f56a9e36fc96cab67eb225376e5eaef2 100644 (file)
@@ -778,7 +778,8 @@ class SingleCycleTest(UOWTest):
         # mysql can't handle delete from nodes
         # since it doesn't deal with the FKs correctly,
         # so wipe out the parent_id first
-        testing.db.execute(self.tables.nodes.update().values(parent_id=None))
+        with testing.db.begin() as conn:
+            conn.execute(self.tables.nodes.update().values(parent_id=None))
         super(SingleCycleTest, self).teardown()
 
     def test_one_to_many_save(self):
index 4a6ebd0c83cfc35d60743724f929241e17e65eec..2a2e70bc3941551948758c49a8ce18f7d5922329 100644 (file)
@@ -1012,9 +1012,7 @@ class PKIncrementTest(fixtures.TablesTest):
             Column("str1", String(20)),
         )
 
-    # TODO: add coverage for increment on a secondary column in a key
-    @testing.fails_on("firebird", "Data type unknown")
-    def _test_autoincrement(self, connection):
+    def test_autoincrement(self, connection):
         aitable = self.tables.aitable
 
         ids = set()
@@ -1064,14 +1062,6 @@ class PKIncrementTest(fixtures.TablesTest):
             ],
         )
 
-    def test_autoincrement_autocommit(self):
-        with testing.db.connect() as conn:
-            self._test_autoincrement(conn)
-
-    def test_autoincrement_transaction(self):
-        with testing.db.begin() as conn:
-            self._test_autoincrement(conn)
-
 
 class EmptyInsertTest(fixtures.TestBase):
     __backend__ = True
@@ -1267,7 +1257,7 @@ class SpecialTypePKTest(fixtures.TestBase):
             implicit_returning=implicit_returning,
         )
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             t.create(conn)
             r = conn.execute(t.insert().values(data=5))
 
index 934022560f461110b5eab958a031226dea156fe1..6f7b3f8f5d80659daf4bf96a8d909dc9cacb281a 100644 (file)
@@ -308,32 +308,31 @@ class DeleteFromRoundTripTest(fixtures.TablesTest):
         )
 
     @testing.requires.delete_from
-    def test_exec_two_table(self):
+    def test_exec_two_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
         dingalings = self.tables.dingalings
 
-        with testing.db.connect() as conn:
-            conn.execute(dingalings.delete())  # fk violation otherwise
+        connection.execute(dingalings.delete())  # fk violation otherwise
 
-            conn.execute(
-                addresses.delete()
-                .where(users.c.id == addresses.c.user_id)
-                .where(users.c.name == "ed")
-            )
+        connection.execute(
+            addresses.delete()
+            .where(users.c.id == addresses.c.user_id)
+            .where(users.c.name == "ed")
+        )
 
-            expected = [
-                (1, 7, "x", "jack@bean.com"),
-                (5, 9, "x", "fred@fred.com"),
-            ]
-        self._assert_table(addresses, expected)
+        expected = [
+            (1, 7, "x", "jack@bean.com"),
+            (5, 9, "x", "fred@fred.com"),
+        ]
+        self._assert_table(connection, addresses, expected)
 
     @testing.requires.delete_from
-    def test_exec_three_table(self):
+    def test_exec_three_table(self, connection):
         users = self.tables.users
         addresses = self.tables.addresses
         dingalings = self.tables.dingalings
 
-        testing.db.execute(
+        connection.execute(
             dingalings.delete()
             .where(users.c.id == addresses.c.user_id)
             .where(users.c.name == "ed")
@@ -341,34 +340,33 @@ class DeleteFromRoundTripTest(fixtures.TablesTest):
         )
 
         expected = [(2, 5, "ding 2/5")]
-        self._assert_table(dingalings, expected)
+        self._assert_table(connection, dingalings, expected)
 
     @testing.requires.delete_from
-    def test_exec_two_table_plus_alias(self):
+    def test_exec_two_table_plus_alias(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
         dingalings = self.tables.dingalings
 
-        with testing.db.connect() as conn:
-            conn.execute(dingalings.delete())  # fk violation otherwise
-            a1 = addresses.alias()
-            conn.execute(
-                addresses.delete()
-                .where(users.c.id == addresses.c.user_id)
-                .where(users.c.name == "ed")
-                .where(a1.c.id == addresses.c.id)
-            )
+        connection.execute(dingalings.delete())  # fk violation otherwise
+        a1 = addresses.alias()
+        connection.execute(
+            addresses.delete()
+            .where(users.c.id == addresses.c.user_id)
+            .where(users.c.name == "ed")
+            .where(a1.c.id == addresses.c.id)
+        )
 
         expected = [(1, 7, "x", "jack@bean.com"), (5, 9, "x", "fred@fred.com")]
-        self._assert_table(addresses, expected)
+        self._assert_table(connection, addresses, expected)
 
     @testing.requires.delete_from
-    def test_exec_alias_plus_table(self):
+    def test_exec_alias_plus_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
         dingalings = self.tables.dingalings
 
         d1 = dingalings.alias()
 
-        testing.db.execute(
+        connection.execute(
             delete(d1)
             .where(users.c.id == addresses.c.user_id)
             .where(users.c.name == "ed")
@@ -376,8 +374,8 @@ class DeleteFromRoundTripTest(fixtures.TablesTest):
         )
 
         expected = [(2, 5, "ding 2/5")]
-        self._assert_table(dingalings, expected)
+        self._assert_table(connection, dingalings, expected)
 
-    def _assert_table(self, table, expected):
+    def _assert_table(self, connection, table, expected):
         stmt = table.select().order_by(table.c.id)
-        eq_(testing.db.execute(stmt).fetchall(), expected)
+        eq_(connection.execute(stmt).fetchall(), expected)
index c0d2e87e8640d1d95ecbdf7f4e080e790956ac4c..e082cf55d09b02d3ec5d3e0b9d126f661da15526 100644 (file)
@@ -23,6 +23,7 @@ from sqlalchemy import MetaData
 from sqlalchemy import null
 from sqlalchemy import or_
 from sqlalchemy import select
+from sqlalchemy import Sequence
 from sqlalchemy import sql
 from sqlalchemy import String
 from sqlalchemy import table
@@ -1271,6 +1272,165 @@ class KeyTargetingTest(fixtures.TablesTest):
             in_(stmt.c.keyed2_b, row)
 
 
+class PKIncrementTest(fixtures.TablesTest):
+    run_define_tables = "each"
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "aitable",
+            metadata,
+            Column(
+                "id",
+                Integer,
+                Sequence("ai_id_seq", optional=True),
+                primary_key=True,
+            ),
+            Column("int1", Integer),
+            Column("str1", String(20)),
+        )
+
+    def _test_autoincrement(self, connection):
+        aitable = self.tables.aitable
+
+        ids = set()
+        rs = connection.execute(aitable.insert(), int1=1)
+        last = rs.inserted_primary_key[0]
+        self.assert_(last)
+        self.assert_(last not in ids)
+        ids.add(last)
+
+        rs = connection.execute(aitable.insert(), str1="row 2")
+        last = rs.inserted_primary_key[0]
+        self.assert_(last)
+        self.assert_(last not in ids)
+        ids.add(last)
+
+        rs = connection.execute(aitable.insert(), int1=3, str1="row 3")
+        last = rs.inserted_primary_key[0]
+        self.assert_(last)
+        self.assert_(last not in ids)
+        ids.add(last)
+
+        rs = connection.execute(
+            aitable.insert().values({"int1": func.length("four")})
+        )
+        last = rs.inserted_primary_key[0]
+        self.assert_(last)
+        self.assert_(last not in ids)
+        ids.add(last)
+
+        eq_(
+            ids,
+            set(
+                range(
+                    testing.db.dialect.default_sequence_base,
+                    testing.db.dialect.default_sequence_base + 4,
+                )
+            ),
+        )
+
+        eq_(
+            list(connection.execute(aitable.select().order_by(aitable.c.id))),
+            [
+                (testing.db.dialect.default_sequence_base, 1, None),
+                (testing.db.dialect.default_sequence_base + 1, None, "row 2"),
+                (testing.db.dialect.default_sequence_base + 2, 3, "row 3"),
+                (testing.db.dialect.default_sequence_base + 3, 4, None),
+            ],
+        )
+
+    def test_autoincrement_autocommit(self):
+        with testing.db.connect() as conn:
+            with testing.expect_deprecated_20(
+                "The current statement is being autocommitted using "
+                "implicit autocommit, "
+            ):
+                self._test_autoincrement(conn)
+
+
+class ConnectionlessCursorResultTest(fixtures.TablesTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column(
+                "user_id", INT, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("user_name", VARCHAR(20)),
+            test_needs_acid=True,
+        )
+
+    def test_connectionless_autoclose_rows_exhausted(self):
+        users = self.tables.users
+        with testing.db.begin() as conn:
+            conn.execute(users.insert(), dict(user_id=1, user_name="john"))
+
+        with testing.expect_deprecated_20(
+            r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+        ):
+            result = testing.db.execute(text("select * from users"))
+        connection = result.connection
+        assert not connection.closed
+        eq_(result.fetchone(), (1, "john"))
+        assert not connection.closed
+        eq_(result.fetchone(), None)
+        assert connection.closed
+
+    @testing.requires.returning
+    def test_connectionless_autoclose_crud_rows_exhausted(self):
+        users = self.tables.users
+        stmt = (
+            users.insert()
+            .values(user_id=1, user_name="john")
+            .returning(users.c.user_id)
+        )
+        with testing.expect_deprecated_20(
+            r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+        ):
+            result = testing.db.execute(stmt)
+        connection = result.connection
+        assert not connection.closed
+        eq_(result.fetchone(), (1,))
+        assert not connection.closed
+        eq_(result.fetchone(), None)
+        assert connection.closed
+
+    def test_connectionless_autoclose_no_rows(self):
+        with testing.expect_deprecated_20(
+            r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+        ):
+            result = testing.db.execute(text("select * from users"))
+        connection = result.connection
+        assert not connection.closed
+        eq_(result.fetchone(), None)
+        assert connection.closed
+
+    @testing.requires.updateable_autoincrement_pks
+    def test_connectionless_autoclose_no_metadata(self):
+        with testing.expect_deprecated_20(
+            r"The (?:Executable|Engine)\.(?:execute|scalar)\(\) method"
+        ):
+            result = testing.db.execute(text("update users set user_id=5"))
+        connection = result.connection
+        assert connection.closed
+
+        assert_raises_message(
+            exc.ResourceClosedError,
+            "This result object does not return rows.",
+            result.fetchone,
+        )
+        assert_raises_message(
+            exc.ResourceClosedError,
+            "This result object does not return rows.",
+            result.keys,
+        )
+
+
 class CursorResultTest(fixtures.TablesTest):
     __backend__ = True
 
@@ -1436,7 +1596,7 @@ class CursorResultTest(fixtures.TablesTest):
     def test_pickled_rows(self):
         users = self.tables.users
         addresses = self.tables.addresses
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(users.delete())
             conn.execute(
                 users.insert(),
@@ -2319,3 +2479,93 @@ class LegacyOperatorTest(AssertsCompiledSQL, fixtures.TestBase):
         _op_modern = getattr(operators.ColumnOperators, _modern)
         _op_legacy = getattr(operators.ColumnOperators, _legacy)
         assert _op_modern == _op_legacy
+
+
+class LegacySequenceExecTest(fixtures.TestBase):
+    __requires__ = ("sequences",)
+    __backend__ = True
+
+    @classmethod
+    def setup_class(cls):
+        cls.seq = Sequence("my_sequence")
+        cls.seq.create(testing.db)
+
+    @classmethod
+    def teardown_class(cls):
+        cls.seq.drop(testing.db)
+
+    def _assert_seq_result(self, ret):
+        """asserts return of next_value is an int"""
+
+        assert isinstance(ret, util.int_types)
+        assert ret >= testing.db.dialect.default_sequence_base
+
+    def test_implicit_connectionless(self):
+        with testing.expect_deprecated_20(
+            r"The MetaData.bind argument is deprecated"
+        ):
+            s = Sequence("my_sequence", metadata=MetaData(testing.db))
+
+        with testing.expect_deprecated_20(
+            r"The DefaultGenerator.execute\(\) method is considered legacy "
+            "as of the 1.x",
+        ):
+            self._assert_seq_result(s.execute())
+
+    def test_explicit(self, connection):
+        s = Sequence("my_sequence")
+        with testing.expect_deprecated_20(
+            r"The DefaultGenerator.execute\(\) method is considered legacy"
+        ):
+            self._assert_seq_result(s.execute(connection))
+
+    def test_explicit_optional(self):
+        """test dialect executes a Sequence, returns nextval, whether
+        or not "optional" is set"""
+
+        s = Sequence("my_sequence", optional=True)
+        with testing.expect_deprecated_20(
+            r"The DefaultGenerator.execute\(\) method is considered legacy"
+        ):
+            self._assert_seq_result(s.execute(testing.db))
+
+    def test_func_implicit_connectionless_execute(self):
+        """test func.next_value().execute()/.scalar() works
+        with connectionless execution."""
+
+        with testing.expect_deprecated_20(
+            r"The MetaData.bind argument is deprecated"
+        ):
+            s = Sequence("my_sequence", metadata=MetaData(testing.db))
+        with testing.expect_deprecated_20(
+            r"The Executable.execute\(\) method is considered legacy"
+        ):
+            self._assert_seq_result(s.next_value().execute().scalar())
+
+    def test_func_explicit(self):
+        s = Sequence("my_sequence")
+        with testing.expect_deprecated_20(
+            r"The Engine.scalar\(\) method is considered legacy"
+        ):
+            self._assert_seq_result(testing.db.scalar(s.next_value()))
+
+    def test_func_implicit_connectionless_scalar(self):
+        """test func.next_value().execute()/.scalar() works. """
+
+        with testing.expect_deprecated_20(
+            r"The MetaData.bind argument is deprecated"
+        ):
+            s = Sequence("my_sequence", metadata=MetaData(testing.db))
+        with testing.expect_deprecated_20(
+            r"The Executable.execute\(\) method is considered legacy"
+        ):
+            self._assert_seq_result(s.next_value().scalar())
+
+    def test_func_embedded_select(self):
+        """test can use next_value() in select column expr"""
+
+        s = Sequence("my_sequence")
+        with testing.expect_deprecated_20(
+            r"The Engine.scalar\(\) method is considered legacy"
+        ):
+            self._assert_seq_result(testing.db.scalar(select(s.next_value())))
index 7d05462abb2fae04af4622023502a7c74bbd7832..6d26f79758c9a1116eccf0192e138cedbb06a609 100644 (file)
@@ -84,7 +84,7 @@ class QueryTest(fixtures.TestBase):
 
     @engines.close_first
     def teardown(self):
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(addresses.delete())
             conn.execute(users.delete())
             conn.execute(users2.delete())
@@ -878,21 +878,22 @@ class RequiredBindTest(fixtures.TablesTest):
         )
 
     def _assert_raises(self, stmt, params):
-        assert_raises_message(
-            exc.StatementError,
-            "A value is required for bind parameter 'x'",
-            testing.db.execute,
-            stmt,
-            **params
-        )
+        with testing.db.connect() as conn:
+            assert_raises_message(
+                exc.StatementError,
+                "A value is required for bind parameter 'x'",
+                conn.execute,
+                stmt,
+                **params
+            )
 
-        assert_raises_message(
-            exc.StatementError,
-            "A value is required for bind parameter 'x'",
-            testing.db.execute,
-            stmt,
-            params,
-        )
+            assert_raises_message(
+                exc.StatementError,
+                "A value is required for bind parameter 'x'",
+                conn.execute,
+                stmt,
+                params,
+            )
 
     def test_insert(self):
         stmt = self.tables.foo.insert().values(
@@ -953,7 +954,7 @@ class LimitTest(fixtures.TestBase):
         )
         metadata.create_all()
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(users.insert(), user_id=1, user_name="john")
             conn.execute(
                 addresses.insert(), address_id=1, user_id=1, address="addr1"
@@ -1105,7 +1106,7 @@ class CompoundTest(fixtures.TestBase):
         )
         metadata.create_all()
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(
                 t1.insert(),
                 [
@@ -1470,7 +1471,7 @@ class JoinTest(fixtures.TestBase):
         metadata.drop_all()
         metadata.create_all()
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             # t1.10 -> t2.20 -> t3.30
             # t1.11 -> t2.21
             # t1.12
@@ -1823,7 +1824,7 @@ class OperatorTest(fixtures.TestBase):
         )
         metadata.create_all()
 
-        with testing.db.connect() as conn:
+        with testing.db.begin() as conn:
             conn.execute(
                 flds.insert(),
                 [dict(intcol=5, strcol="foo"), dict(intcol=13, strcol="bar")],
index 1c023e7b1f7451007f2bcbf538130bdea89fbe73..a78d6c16b5a5bd6ab20e2338ff8d5a42eef2e4b9 100644 (file)
@@ -25,19 +25,12 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing.util import picklers
 
 
-class QuoteExecTest(fixtures.TestBase):
+class QuoteExecTest(fixtures.TablesTest):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        # TODO: figure out which databases/which identifiers allow special
-        # characters to be used, such as: spaces, quote characters,
-        # punctuation characters, set up tests for those as well.
-
-        global table1, table2
-        metadata = MetaData(testing.db)
-
-        table1 = Table(
+    def define_tables(cls, metadata):
+        Table(
             "WorstCase1",
             metadata,
             Column("lowercase", Integer, primary_key=True),
@@ -45,7 +38,7 @@ class QuoteExecTest(fixtures.TestBase):
             Column("MixedCase", Integer),
             Column("ASC", Integer, key="a123"),
         )
-        table2 = Table(
+        Table(
             "WorstCase2",
             metadata,
             Column("desc", Integer, primary_key=True, key="d123"),
@@ -53,18 +46,6 @@ class QuoteExecTest(fixtures.TestBase):
             Column("MixedCase", Integer),
         )
 
-        table1.create()
-        table2.create()
-
-    def teardown(self):
-        table1.delete().execute()
-        table2.delete().execute()
-
-    @classmethod
-    def teardown_class(cls):
-        table1.drop()
-        table2.drop()
-
     def test_reflect(self):
         meta2 = MetaData()
         t2 = Table("WorstCase1", meta2, autoload_with=testing.db, quote=True)
@@ -88,25 +69,22 @@ class QuoteExecTest(fixtures.TestBase):
         assert "MixedCase" in t2.c
 
     @testing.provide_metadata
-    def test_has_table_case_sensitive(self):
+    def test_has_table_case_sensitive(self, connection):
         preparer = testing.db.dialect.identifier_preparer
-        with testing.db.connect() as conn:
-            if conn.dialect.requires_name_normalize:
-                conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)")
-            else:
-                conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)")
-            conn.exec_driver_sql(
-                "CREATE TABLE %s (id INTEGER)"
-                % preparer.quote_identifier("tab2")
-            )
-            conn.exec_driver_sql(
-                "CREATE TABLE %s (id INTEGER)"
-                % preparer.quote_identifier("TAB3")
-            )
-            conn.exec_driver_sql(
-                "CREATE TABLE %s (id INTEGER)"
-                % preparer.quote_identifier("TAB4")
-            )
+        conn = connection
+        if conn.dialect.requires_name_normalize:
+            conn.exec_driver_sql("CREATE TABLE TAB1 (id INTEGER)")
+        else:
+            conn.exec_driver_sql("CREATE TABLE tab1 (id INTEGER)")
+        conn.exec_driver_sql(
+            "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("tab2")
+        )
+        conn.exec_driver_sql(
+            "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB3")
+        )
+        conn.exec_driver_sql(
+            "CREATE TABLE %s (id INTEGER)" % preparer.quote_identifier("TAB4")
+        )
 
         t1 = Table(
             "tab1", self.metadata, Column("id", Integer, primary_key=True)
@@ -127,7 +105,7 @@ class QuoteExecTest(fixtures.TestBase):
             quote=True,
         )
 
-        insp = inspect(testing.db)
+        insp = inspect(connection)
         assert insp.has_table(t1.name)
         eq_([c["name"] for c in insp.get_columns(t1.name)], ["id"])
 
@@ -140,16 +118,24 @@ class QuoteExecTest(fixtures.TestBase):
         assert insp.has_table(t4.name)
         eq_([c["name"] for c in insp.get_columns(t4.name)], ["id"])
 
-    def test_basic(self):
-        table1.insert().execute(
-            {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
-            {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
-            {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
+    def test_basic(self, connection):
+        table1, table2 = self.tables("WorstCase1", "WorstCase2")
+
+        connection.execute(
+            table1.insert(),
+            [
+                {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+                {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+                {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
+            ],
         )
-        table2.insert().execute(
-            {"d123": 1, "u123": 2, "MixedCase": 3},
-            {"d123": 2, "u123": 2, "MixedCase": 3},
-            {"d123": 4, "u123": 3, "MixedCase": 2},
+        connection.execute(
+            table2.insert(),
+            [
+                {"d123": 1, "u123": 2, "MixedCase": 3},
+                {"d123": 2, "u123": 2, "MixedCase": 3},
+                {"d123": 4, "u123": 3, "MixedCase": 2},
+            ],
         )
 
         columns = [
@@ -158,23 +144,30 @@ class QuoteExecTest(fixtures.TestBase):
             table1.c.MixedCase,
             table1.c.a123,
         ]
-        result = select(columns).execute().fetchall()
+        result = connection.execute(select(columns)).all()
         assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)]
 
         columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase]
-        result = select(columns).execute().fetchall()
+        result = connection.execute(select(columns)).all()
         assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)]
 
-    def test_use_labels(self):
-        table1.insert().execute(
-            {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
-            {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
-            {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
-        )
-        table2.insert().execute(
-            {"d123": 1, "u123": 2, "MixedCase": 3},
-            {"d123": 2, "u123": 2, "MixedCase": 3},
-            {"d123": 4, "u123": 3, "MixedCase": 2},
+    def test_use_labels(self, connection):
+        table1, table2 = self.tables("WorstCase1", "WorstCase2")
+        connection.execute(
+            table1.insert(),
+            [
+                {"lowercase": 1, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+                {"lowercase": 2, "UPPERCASE": 2, "MixedCase": 3, "a123": 4},
+                {"lowercase": 4, "UPPERCASE": 3, "MixedCase": 2, "a123": 1},
+            ],
+        )
+        connection.execute(
+            table2.insert(),
+            [
+                {"d123": 1, "u123": 2, "MixedCase": 3},
+                {"d123": 2, "u123": 2, "MixedCase": 3},
+                {"d123": 4, "u123": 3, "MixedCase": 2},
+            ],
         )
 
         columns = [
@@ -183,11 +176,11 @@ class QuoteExecTest(fixtures.TestBase):
             table1.c.MixedCase,
             table1.c.a123,
         ]
-        result = select(columns, use_labels=True).execute().fetchall()
+        result = connection.execute(select(columns).apply_labels()).fetchall()
         assert result == [(1, 2, 3, 4), (2, 2, 3, 4), (4, 3, 2, 1)]
 
         columns = [table2.c.d123, table2.c.u123, table2.c.MixedCase]
-        result = select(columns, use_labels=True).execute().fetchall()
+        result = connection.execute(select(columns).apply_labels()).all()
         assert result == [(1, 2, 3), (2, 2, 3), (4, 3, 2)]
 
 
index 9ef533be3a555e988822757683fca7816ab6eecb..db0e0d4c8160df1adc1ff8e10f2b83a66777afe2 100644 (file)
@@ -615,63 +615,6 @@ class CursorResultTest(fixtures.TablesTest):
             result.fetchone,
         )
 
-    def test_connectionless_autoclose_rows_exhausted(self):
-        # TODO: deprecate for 2.0
-        users = self.tables.users
-        with testing.db.connect() as conn:
-            conn.execute(users.insert(), dict(user_id=1, user_name="john"))
-
-        result = testing.db.execute(text("select * from users"))
-        connection = result.connection
-        assert not connection.closed
-        eq_(result.fetchone(), (1, "john"))
-        assert not connection.closed
-        eq_(result.fetchone(), None)
-        assert connection.closed
-
-    @testing.requires.returning
-    def test_connectionless_autoclose_crud_rows_exhausted(self):
-        # TODO: deprecate for 2.0
-        users = self.tables.users
-        stmt = (
-            users.insert()
-            .values(user_id=1, user_name="john")
-            .returning(users.c.user_id)
-        )
-        result = testing.db.execute(stmt)
-        connection = result.connection
-        assert not connection.closed
-        eq_(result.fetchone(), (1,))
-        assert not connection.closed
-        eq_(result.fetchone(), None)
-        assert connection.closed
-
-    def test_connectionless_autoclose_no_rows(self):
-        # TODO: deprecate for 2.0
-        result = testing.db.execute(text("select * from users"))
-        connection = result.connection
-        assert not connection.closed
-        eq_(result.fetchone(), None)
-        assert connection.closed
-
-    @testing.requires.updateable_autoincrement_pks
-    def test_connectionless_autoclose_no_metadata(self):
-        # TODO: deprecate for 2.0
-        result = testing.db.execute(text("update users set user_id=5"))
-        connection = result.connection
-        assert connection.closed
-
-        assert_raises_message(
-            exc.ResourceClosedError,
-            "This result object does not return rows.",
-            result.fetchone,
-        )
-        assert_raises_message(
-            exc.ResourceClosedError,
-            "This result object does not return rows.",
-            result.keys,
-        )
-
     def test_row_case_sensitive(self, connection):
         row = connection.execute(
             select(
@@ -1285,7 +1228,7 @@ class CursorResultTest(fixtures.TablesTest):
         with patch.object(
             engine.dialect.execution_ctx_cls, "rowcount"
         ) as mock_rowcount:
-            with engine.connect() as conn:
+            with engine.begin() as conn:
                 mock_rowcount.__get__ = Mock()
                 conn.execute(
                     t.insert(), {"data": "d1"}, {"data": "d2"}, {"data": "d3"}
@@ -1362,20 +1305,14 @@ class CursorResultTest(fixtures.TablesTest):
         eq_(row[1:0:-1], ("Uno",))
 
     @testing.requires.cextensions
-    def test_row_c_sequence_check(self):
-        # TODO: modernize for 2.0
-        metadata = MetaData()
-        metadata.bind = "sqlite://"
-        users = Table(
-            "users",
-            metadata,
-            Column("id", Integer, primary_key=True),
-            Column("name", String(40)),
-        )
-        users.create()
+    @testing.provide_metadata
+    def test_row_c_sequence_check(self, connection):
+        users = self.tables.users2
 
-        users.insert().execute(name="Test")
-        row = users.select().execute().fetchone()
+        connection.execute(users.insert(), dict(user_id=1, user_name="Test"))
+        row = connection.execute(
+            users.select().where(users.c.user_id == 1)
+        ).fetchone()
 
         s = util.StringIO()
         writer = csv.writer(s)
@@ -2340,7 +2277,7 @@ class AlternateCursorResultTest(fixtures.TablesTest):
     @testing.fixture
     def row_growth_fixture(self):
         with self._proxy_fixture(_cursor.BufferedRowCursorFetchStrategy):
-            with self.engine.connect() as conn:
+            with self.engine.begin() as conn:
                 conn.execute(
                     self.table.insert(),
                     [{"x": i, "y": "t_%d" % i} for i in range(15, 3000)],
index 065205c45a90ee3581f4fd0f2e03c2b137f23dbe..9f2afd7b7dad1fc33b8c96f3a14b0f17ae29701f 100644 (file)
@@ -23,9 +23,6 @@ from sqlalchemy.testing.schema import Table
 from sqlalchemy.types import TypeDecorator
 
 
-table = GoofyType = seq = None
-
-
 class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "postgresql"
 
@@ -92,14 +89,14 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
-class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
+class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
     __requires__ = ("returning",)
     __backend__ = True
 
-    def setup(self):
-        meta = MetaData(testing.db)
-        global table, GoofyType
+    run_create_tables = "each"
 
+    @classmethod
+    def define_tables(cls, metadata):
         class GoofyType(TypeDecorator):
             impl = String
 
@@ -113,9 +110,11 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
                     return None
                 return value + "BAR"
 
-        table = Table(
+        cls.GoofyType = GoofyType
+
+        Table(
             "tables",
-            meta,
+            metadata,
             Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
             ),
@@ -123,14 +122,9 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
             Column("full", Boolean),
             Column("goofy", GoofyType(50)),
         )
-        with testing.db.connect() as conn:
-            table.create(conn, checkfirst=True)
-
-    def teardown(self):
-        with testing.db.connect() as conn:
-            table.drop(conn)
 
     def test_column_targeting(self, connection):
+        table = self.tables.tables
         result = connection.execute(
             table.insert().returning(table.c.id, table.c.full),
             {"persons": 1, "full": False},
@@ -155,6 +149,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
 
     @testing.fails_on("firebird", "fb can't handle returning x AS y")
     def test_labeling(self, connection):
+        table = self.tables.tables
         result = connection.execute(
             table.insert()
             .values(persons=6)
@@ -167,6 +162,8 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         "firebird", "fb/kintersbasdb can't handle the bind params"
     )
     def test_anon_expressions(self, connection):
+        table = self.tables.tables
+        GoofyType = self.GoofyType
         result = connection.execute(
             table.insert()
             .values(goofy="someOTHERgoofy")
@@ -182,6 +179,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(row[0], 30)
 
     def test_update_returning(self, connection):
+        table = self.tables.tables
         connection.execute(
             table.insert(),
             [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -201,6 +199,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
 
     @testing.requires.full_returning
     def test_update_full_returning(self, connection):
+        table = self.tables.tables
         connection.execute(
             table.insert(),
             [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -215,6 +214,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
 
     @testing.requires.full_returning
     def test_delete_full_returning(self, connection):
+        table = self.tables.tables
         connection.execute(
             table.insert(),
             [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -226,6 +226,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(result.fetchall(), [(1, False), (2, False)])
 
     def test_insert_returning(self, connection):
+        table = self.tables.tables
         result = connection.execute(
             table.insert().returning(table.c.id), {"persons": 1, "full": False}
         )
@@ -234,6 +235,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
 
     @testing.requires.multivalues_inserts
     def test_multirow_returning(self, connection):
+        table = self.tables.tables
         ins = (
             table.insert()
             .returning(table.c.id, table.c.persons)
@@ -249,6 +251,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
 
     def test_no_ipk_on_returning(self, connection):
+        table = self.tables.tables
         result = connection.execute(
             table.insert().returning(table.c.id), {"persons": 1, "full": False}
         )
@@ -274,6 +277,7 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         eq_([dict(row._mapping) for row in result4], [{"persons": 10}])
 
     def test_delete_returning(self, connection):
+        table = self.tables.tables
         connection.execute(
             table.insert(),
             [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
@@ -319,17 +323,16 @@ class CompositeStatementTest(fixtures.TestBase):
         eq_(result.scalar(), 5)
 
 
-class SequenceReturningTest(fixtures.TestBase):
+class SequenceReturningTest(fixtures.TablesTest):
     __requires__ = "returning", "sequences"
     __backend__ = True
 
-    def setup(self):
-        meta = MetaData(testing.db)
-        global table, seq
+    @classmethod
+    def define_tables(cls, metadata):
         seq = Sequence("tid_seq")
-        table = Table(
+        Table(
             "tables",
-            meta,
+            metadata,
             Column(
                 "id",
                 Integer,
@@ -338,38 +341,32 @@ class SequenceReturningTest(fixtures.TestBase):
             ),
             Column("data", String(50)),
         )
-        with testing.db.connect() as conn:
-            table.create(conn, checkfirst=True)
-
-    def teardown(self):
-        with testing.db.connect() as conn:
-            table.drop(conn)
+        cls.sequences.tid_seq = seq
 
     def test_insert(self, connection):
+        table = self.tables.tables
         r = connection.execute(
             table.insert().values(data="hi").returning(table.c.id)
         )
         eq_(r.first(), tuple([testing.db.dialect.default_sequence_base]))
         eq_(
-            connection.execute(seq),
+            connection.execute(self.sequences.tid_seq),
             testing.db.dialect.default_sequence_base + 1,
         )
 
 
-class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
+class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults):
 
     """test returning() works with columns that define 'key'."""
 
     __requires__ = ("returning",)
     __backend__ = True
 
-    def setup(self):
-        meta = MetaData(testing.db)
-        global table
-
-        table = Table(
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
             "tables",
-            meta,
+            metadata,
             Column(
                 "id",
                 Integer,
@@ -379,16 +376,11 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             Column("data", String(20)),
         )
-        with testing.db.connect() as conn:
-            table.create(conn, checkfirst=True)
-
-    def teardown(self):
-        with testing.db.connect() as conn:
-            table.drop(conn)
 
     @testing.exclude("firebird", "<", (2, 0), "2.0+ feature")
     @testing.exclude("postgresql", "<", (8, 2), "8.2+ feature")
     def test_insert(self, connection):
+        table = self.tables.tables
         result = connection.execute(
             table.insert().returning(table.c.foo_id), data="somedata"
         )
index e609a8a91618a13caecc93c71e777a86de4d1ff8..1809e0cca0752fae9aef6d90f8c41dd33785e12c 100644 (file)
@@ -95,64 +95,6 @@ class SequenceDDLTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         )
 
 
-class LegacySequenceExecTest(fixtures.TestBase):
-    __requires__ = ("sequences",)
-    __backend__ = True
-
-    @classmethod
-    def setup_class(cls):
-        cls.seq = Sequence("my_sequence")
-        cls.seq.create(testing.db)
-
-    @classmethod
-    def teardown_class(cls):
-        cls.seq.drop(testing.db)
-
-    def _assert_seq_result(self, ret):
-        """asserts return of next_value is an int"""
-
-        assert isinstance(ret, util.int_types)
-        assert ret >= testing.db.dialect.default_sequence_base
-
-    def test_implicit_connectionless(self):
-        s = Sequence("my_sequence", metadata=MetaData(testing.db))
-        self._assert_seq_result(s.execute())
-
-    def test_explicit(self, connection):
-        s = Sequence("my_sequence")
-        self._assert_seq_result(s.execute(connection))
-
-    def test_explicit_optional(self):
-        """test dialect executes a Sequence, returns nextval, whether
-        or not "optional" is set"""
-
-        s = Sequence("my_sequence", optional=True)
-        self._assert_seq_result(s.execute(testing.db))
-
-    def test_func_implicit_connectionless_execute(self):
-        """test func.next_value().execute()/.scalar() works
-        with connectionless execution."""
-
-        s = Sequence("my_sequence", metadata=MetaData(testing.db))
-        self._assert_seq_result(s.next_value().execute().scalar())
-
-    def test_func_explicit(self):
-        s = Sequence("my_sequence")
-        self._assert_seq_result(testing.db.scalar(s.next_value()))
-
-    def test_func_implicit_connectionless_scalar(self):
-        """test func.next_value().execute()/.scalar() works. """
-
-        s = Sequence("my_sequence", metadata=MetaData(testing.db))
-        self._assert_seq_result(s.next_value().scalar())
-
-    def test_func_embedded_select(self):
-        """test can use next_value() in select column expr"""
-
-        s = Sequence("my_sequence")
-        self._assert_seq_result(testing.db.scalar(select(s.next_value())))
-
-
 class SequenceExecTest(fixtures.TestBase):
     __requires__ = ("sequences",)
     __backend__ = True
@@ -247,7 +189,7 @@ class SequenceExecTest(fixtures.TestBase):
         s = Sequence("my_sequence_here", metadata=metadata)
 
         e = engines.testing_engine(options={"implicit_returning": False})
-        with e.connect() as conn:
+        with e.begin() as conn:
 
             t1.create(conn)
             s.create(conn)
@@ -279,7 +221,7 @@ class SequenceExecTest(fixtures.TestBase):
         t1.create(testing.db)
 
         e = engines.testing_engine(options={"implicit_returning": True})
-        with e.connect() as conn:
+        with e.begin() as conn:
             r = conn.execute(t1.insert().values(x=s.next_value()))
             self._assert_seq_result(r.inserted_primary_key[0])
 
@@ -476,7 +418,7 @@ class TableBoundSequenceTest(fixtures.TablesTest):
 
         engine = engines.testing_engine(options={"implicit_returning": False})
 
-        with engine.connect() as conn:
+        with engine.begin() as conn:
             result = conn.execute(sometable.insert(), dict(name="somename"))
 
             eq_(result.postfetch_cols(), [sometable.c.obj_id])
index 09ade319e2f3e0bb44213bf1249d285c6fd85054..719f8e3187aa5f08e7db8dedecf3c0fb605f904b 100644 (file)
@@ -359,34 +359,34 @@ class RoundTripTestBase(object):
             [("X1", "Y1"), ("X2", "Y2"), ("X3", "Y3")],
         )
 
-    def test_targeting_no_labels(self):
-        testing.db.execute(
+    def test_targeting_no_labels(self, connection):
+        connection.execute(
             self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
-        row = testing.db.execute(select(self.tables.test_table)).first()
+        row = connection.execute(select(self.tables.test_table)).first()
         eq_(row._mapping[self.tables.test_table.c.y], "Y1")
 
-    def test_targeting_by_string(self):
-        testing.db.execute(
+    def test_targeting_by_string(self, connection):
+        connection.execute(
             self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
-        row = testing.db.execute(select(self.tables.test_table)).first()
+        row = connection.execute(select(self.tables.test_table)).first()
         eq_(row._mapping["y"], "Y1")
 
-    def test_targeting_apply_labels(self):
-        testing.db.execute(
+    def test_targeting_apply_labels(self, connection):
+        connection.execute(
             self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
-        row = testing.db.execute(
+        row = connection.execute(
             select(self.tables.test_table).apply_labels()
         ).first()
         eq_(row._mapping[self.tables.test_table.c.y], "Y1")
 
-    def test_targeting_individual_labels(self):
-        testing.db.execute(
+    def test_targeting_individual_labels(self, connection):
+        connection.execute(
             self.tables.test_table.insert(), {"x": "X1", "y": "Y1"}
         )
-        row = testing.db.execute(
+        row = connection.execute(
             select(
                 self.tables.test_table.c.x.label("xbar"),
                 self.tables.test_table.c.y.label("ybar"),
@@ -450,9 +450,9 @@ class ReturningTest(fixtures.TablesTest):
         )
 
     @testing.provide_metadata
-    def test_insert_returning(self):
+    def test_insert_returning(self, connection):
         table = self.tables.test_table
-        result = testing.db.execute(
+        result = connection.execute(
             table.insert().returning(table.c.y), {"x": "xvalue"}
         )
         eq_(result.first(), ("yvalue",))
index fd1783e09806177acb953da3225526f493e20ec4..3f89d438a6a6bac9dd0645abe0f4ccb66f29963d 100644 (file)
@@ -535,49 +535,48 @@ class _UserDefinedTypeFixture(object):
 class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
     __backend__ = True
 
-    def _data_fixture(self):
+    def _data_fixture(self, connection):
         users = self.tables.users
-        with testing.db.connect() as conn:
-            conn.execute(
-                users.insert(),
-                dict(
-                    user_id=2,
-                    goofy="jack",
-                    goofy2="jack",
-                    goofy4=util.u("jack"),
-                    goofy7=util.u("jack"),
-                    goofy8=12,
-                    goofy9=12,
-                ),
-            )
-            conn.execute(
-                users.insert(),
-                dict(
-                    user_id=3,
-                    goofy="lala",
-                    goofy2="lala",
-                    goofy4=util.u("lala"),
-                    goofy7=util.u("lala"),
-                    goofy8=15,
-                    goofy9=15,
-                ),
-            )
-            conn.execute(
-                users.insert(),
-                dict(
-                    user_id=4,
-                    goofy="fred",
-                    goofy2="fred",
-                    goofy4=util.u("fred"),
-                    goofy7=util.u("fred"),
-                    goofy8=9,
-                    goofy9=9,
-                ),
-            )
+        connection.execute(
+            users.insert(),
+            dict(
+                user_id=2,
+                goofy="jack",
+                goofy2="jack",
+                goofy4=util.u("jack"),
+                goofy7=util.u("jack"),
+                goofy8=12,
+                goofy9=12,
+            ),
+        )
+        connection.execute(
+            users.insert(),
+            dict(
+                user_id=3,
+                goofy="lala",
+                goofy2="lala",
+                goofy4=util.u("lala"),
+                goofy7=util.u("lala"),
+                goofy8=15,
+                goofy9=15,
+            ),
+        )
+        connection.execute(
+            users.insert(),
+            dict(
+                user_id=4,
+                goofy="fred",
+                goofy2="fred",
+                goofy4=util.u("fred"),
+                goofy7=util.u("fred"),
+                goofy8=9,
+                goofy9=9,
+            ),
+        )
 
     def test_processing(self, connection):
         users = self.tables.users
-        self._data_fixture()
+        self._data_fixture(connection)
 
         result = connection.execute(
             users.select().order_by(users.c.user_id)
@@ -601,7 +600,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
 
     def test_plain_in(self, connection):
         users = self.tables.users
-        self._data_fixture()
+        self._data_fixture(connection)
 
         stmt = (
             select(users.c.user_id, users.c.goofy8)
@@ -613,7 +612,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
 
     def test_expanding_in(self, connection):
         users = self.tables.users
-        self._data_fixture()
+        self._data_fixture(connection)
 
         stmt = (
             select(users.c.user_id, users.c.goofy8)
@@ -1225,41 +1224,38 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
 
     @testing.only_on("sqlite")
     @testing.provide_metadata
-    def test_round_trip(self):
+    def test_round_trip(self, connection):
         variant = self.UTypeOne().with_variant(self.UTypeTwo(), "sqlite")
 
         t = Table("t", self.metadata, Column("x", variant))
-        with testing.db.connect() as conn:
-            t.create(conn)
+        t.create(connection)
 
-            conn.execute(t.insert(), x="foo")
+        connection.execute(t.insert(), x="foo")
 
-            eq_(conn.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO")
+        eq_(connection.scalar(select(t.c.x).where(t.c.x == "foo")), "fooUTWO")
 
     @testing.only_on("sqlite")
     @testing.provide_metadata
-    def test_round_trip_sqlite_datetime(self):
+    def test_round_trip_sqlite_datetime(self, connection):
         variant = DateTime().with_variant(
             dialects.sqlite.DATETIME(truncate_microseconds=True), "sqlite"
         )
 
         t = Table("t", self.metadata, Column("x", variant))
-        with testing.db.connect() as conn:
-            t.create(conn)
+        t.create(connection)
 
-            conn.execute(
-                t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
-            )
+        connection.execute(
+            t.insert(), x=datetime.datetime(2015, 4, 18, 10, 15, 17, 4839)
+        )
 
-            eq_(
-                conn.scalar(
-                    select(t.c.x).where(
-                        t.c.x
-                        == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059)
-                    )
-                ),
-                datetime.datetime(2015, 4, 18, 10, 15, 17),
-            )
+        eq_(
+            connection.scalar(
+                select(t.c.x).where(
+                    t.c.x == datetime.datetime(2015, 4, 18, 10, 15, 17, 1059)
+                )
+            ),
+            datetime.datetime(2015, 4, 18, 10, 15, 17),
+        )
 
 
 class UnicodeTest(fixtures.TestBase):
@@ -1702,14 +1698,25 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             2,
         )
 
-        with testing.db.connect() as conn:
-            self.metadata.create_all(conn)
+        self.metadata.create_all(testing.db)
+
+        # not using the connection fixture because we need to rollback and
+        # start again in the middle
+        with testing.db.connect() as connection:
+            # postgresql needs this in order to continue after the exception
+            trans = connection.begin()
             assert_raises(
                 (exc.DBAPIError,),
-                conn.exec_driver_sql,
+                connection.exec_driver_sql,
                 "insert into my_table " "(data) values('four')",
             )
-            conn.exec_driver_sql("insert into my_table (data) values ('two')")
+            trans.rollback()
+
+            with connection.begin():
+                connection.exec_driver_sql(
+                    "insert into my_table (data) values ('two')"
+                )
+                eq_(connection.execute(select(t.c.data)).scalar(), "two")
 
     @testing.requires.enforces_check_constraints
     @testing.provide_metadata
@@ -1747,34 +1754,44 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
             2,
         )
 
-        with testing.db.connect() as conn:
-            self.metadata.create_all(conn)
+        self.metadata.create_all(testing.db)
+
+        # not using the connection fixture because we need to rollback and
+        # start again in the middle
+        with testing.db.connect() as connection:
+            # postgresql needs this in order to continue after the exception
+            trans = connection.begin()
             assert_raises(
                 (exc.DBAPIError,),
-                conn.exec_driver_sql,
+                connection.exec_driver_sql,
                 "insert into my_table " "(data) values('two')",
             )
-            conn.exec_driver_sql("insert into my_table (data) values ('four')")
+            trans.rollback()
 
-    def test_skip_check_constraint(self):
-        with testing.db.connect() as conn:
-            conn.exec_driver_sql(
-                "insert into non_native_enum_table "
-                "(id, someotherenum) values(1, 'four')"
-            )
-            eq_(
-                conn.exec_driver_sql(
-                    "select someotherenum from non_native_enum_table"
-                ).scalar(),
-                "four",
-            )
-            assert_raises_message(
-                LookupError,
-                "'four' is not among the defined enum values. "
-                "Enum name: None. Possible values: one, two, three",
-                conn.scalar,
-                select(self.tables.non_native_enum_table.c.someotherenum),
-            )
+            with connection.begin():
+                connection.exec_driver_sql(
+                    "insert into my_table (data) values ('four')"
+                )
+                eq_(connection.execute(select(t.c.data)).scalar(), "four")
+
+    def test_skip_check_constraint(self, connection):
+        connection.exec_driver_sql(
+            "insert into non_native_enum_table "
+            "(id, someotherenum) values(1, 'four')"
+        )
+        eq_(
+            connection.exec_driver_sql(
+                "select someotherenum from non_native_enum_table"
+            ).scalar(),
+            "four",
+        )
+        assert_raises_message(
+            LookupError,
+            "'four' is not among the defined enum values. "
+            "Enum name: None. Possible values: one, two, three",
+            connection.scalar,
+            select(self.tables.non_native_enum_table.c.someotherenum),
+        )
 
     def test_non_native_round_trip(self, connection):
         non_native_enum_table = self.tables["non_native_enum_table"]
@@ -2086,15 +2103,15 @@ class EnumTest(AssertsCompiledSQL, fixtures.TablesTest):
         eq_(e.length, 42)
 
 
-binary_table = MyPickleType = metadata = None
+MyPickleType = None
 
 
-class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
+class BinaryTest(fixtures.TablesTest, AssertsExecutionResults):
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global binary_table, MyPickleType, metadata
+    def define_tables(cls, metadata):
+        global MyPickleType
 
         class MyPickleType(types.TypeDecorator):
             impl = PickleType
@@ -2109,8 +2126,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
                     value.stuff = "this is the right stuff"
                 return value
 
-        metadata = MetaData(testing.db)
-        binary_table = Table(
+        Table(
             "binary_table",
             metadata,
             Column(
@@ -2125,19 +2141,11 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
             Column("pickled", PickleType),
             Column("mypickle", MyPickleType),
         )
-        metadata.create_all()
-
-    @engines.close_first
-    def teardown(self):
-        with testing.db.connect() as conn:
-            conn.execute(binary_table.delete())
-
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
 
     @testing.requires.non_broken_binary
     def test_round_trip(self, connection):
+        binary_table = self.tables.binary_table
+
         testobj1 = pickleable.Foo("im foo 1")
         testobj2 = pickleable.Foo("im foo 2")
         testobj3 = pickleable.Foo("im foo 3")
@@ -2197,6 +2205,7 @@ class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
     @testing.requires.binary_comparisons
     def test_comparison(self, connection):
         """test that type coercion occurs on comparison for binary"""
+        binary_table = self.tables.binary_table
 
         expr = binary_table.c.data == "foo"
         assert isinstance(expr.right.type, LargeBinary)
@@ -2419,17 +2428,17 @@ class ArrayTest(fixtures.TestBase):
         assert isinstance(arrtable.c.strarr[1:3].type, MyArray)
 
 
-test_table = meta = MyCustomType = MyTypeDec = None
+MyCustomType = MyTypeDec = None
 
 
 class ExpressionTest(
-    fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
+    fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL
 ):
     __dialect__ = "default"
 
     @classmethod
-    def setup_class(cls):
-        global test_table, meta, MyCustomType, MyTypeDec
+    def define_tables(cls, metadata):
+        global MyCustomType, MyTypeDec
 
         class MyCustomType(types.UserDefinedType):
             def get_col_spec(self):
@@ -2463,10 +2472,9 @@ class ExpressionTest(
             def process_result_value(self, value, dialect):
                 return value + "BIND_OUT"
 
-        meta = MetaData(testing.db)
-        test_table = Table(
+        Table(
             "test",
-            meta,
+            metadata,
             Column("id", Integer, primary_key=True),
             Column("data", String(30)),
             Column("atimestamp", Date),
@@ -2474,25 +2482,22 @@ class ExpressionTest(
             Column("bvalue", MyTypeDec(50)),
         )
 
-        meta.create_all()
-
-        with testing.db.connect() as conn:
-            conn.execute(
-                test_table.insert(),
-                {
-                    "id": 1,
-                    "data": "somedata",
-                    "atimestamp": datetime.date(2007, 10, 15),
-                    "avalue": 25,
-                    "bvalue": "foo",
-                },
-            )
-
     @classmethod
-    def teardown_class(cls):
-        meta.drop_all()
+    def insert_data(cls, connection):
+        test_table = cls.tables.test
+        connection.execute(
+            test_table.insert(),
+            {
+                "id": 1,
+                "data": "somedata",
+                "atimestamp": datetime.date(2007, 10, 15),
+                "avalue": 25,
+                "bvalue": "foo",
+            },
+        )
 
     def test_control(self, connection):
+        test_table = self.tables.test
         assert (
             connection.exec_driver_sql("select avalue from test").scalar()
             == 250
@@ -2513,6 +2518,9 @@ class ExpressionTest(
 
     def test_bind_adapt(self, connection):
         # test an untyped bind gets the left side's type
+
+        test_table = self.tables.test
+
         expr = test_table.c.atimestamp == bindparam("thedate")
         eq_(expr.right.type._type_affinity, Date)
 
@@ -2565,6 +2573,8 @@ class ExpressionTest(
         )
 
     def test_grouped_bind_adapt(self):
+        test_table = self.tables.test
+
         expr = test_table.c.atimestamp == elements.Grouping(
             bindparam("thedate")
         )
@@ -2579,6 +2589,8 @@ class ExpressionTest(
         eq_(expr.right.element.element.type._type_affinity, Date)
 
     def test_bind_adapt_update(self):
+        test_table = self.tables.test
+
         bp = bindparam("somevalue")
         stmt = test_table.update().values(avalue=bp)
         compiled = stmt.compile()
@@ -2586,13 +2598,17 @@ class ExpressionTest(
         eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType)
 
     def test_bind_adapt_insert(self):
+        test_table = self.tables.test
         bp = bindparam("somevalue")
+
         stmt = test_table.insert().values(avalue=bp)
         compiled = stmt.compile()
         eq_(bp.type._type_affinity, types.NullType)
         eq_(compiled.binds["somevalue"].type._type_affinity, MyCustomType)
 
     def test_bind_adapt_expression(self):
+        test_table = self.tables.test
+
         bp = bindparam("somevalue")
         stmt = test_table.c.avalue == bp
         eq_(bp.type._type_affinity, types.NullType)
@@ -2629,6 +2645,8 @@ class ExpressionTest(
         is_(literal(data).type.__class__, expected)
 
     def test_typedec_operator_adapt(self, connection):
+        test_table = self.tables.test
+
         expr = test_table.c.bvalue + "hi"
 
         assert expr.type.__class__ is MyTypeDec
@@ -2846,6 +2864,8 @@ class ExpressionTest(
         eq_(expr.type, types.NULLTYPE)
 
     def test_distinct(self, connection):
+        test_table = self.tables.test
+
         s = select(distinct(test_table.c.avalue))
         eq_(connection.execute(s).scalar(), 25)
 
@@ -3004,17 +3024,18 @@ class NumericRawSQLTest(fixtures.TestBase):
 
     __backend__ = True
 
-    def _fixture(self, metadata, type_, data):
+    def _fixture(self, connection, metadata, type_, data):
         t = Table("t", metadata, Column("val", type_))
-        metadata.create_all()
-        with testing.db.connect() as conn:
-            conn.execute(t.insert(), val=data)
+        metadata.create_all(connection)
+        connection.execute(t.insert(), val=data)
 
     @testing.fails_on("sqlite", "Doesn't provide Decimal results natively")
     @testing.provide_metadata
     def test_decimal_fp(self, connection):
         metadata = self.metadata
-        self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45.5"))
+        self._fixture(
+            connection, metadata, Numeric(10, 5), decimal.Decimal("45.5")
+        )
         val = connection.exec_driver_sql("select val from t").scalar()
         assert isinstance(val, decimal.Decimal)
         eq_(val, decimal.Decimal("45.5"))
@@ -3023,7 +3044,9 @@ class NumericRawSQLTest(fixtures.TestBase):
     @testing.provide_metadata
     def test_decimal_int(self, connection):
         metadata = self.metadata
-        self._fixture(metadata, Numeric(10, 5), decimal.Decimal("45"))
+        self._fixture(
+            connection, metadata, Numeric(10, 5), decimal.Decimal("45")
+        )
         val = connection.exec_driver_sql("select val from t").scalar()
         assert isinstance(val, decimal.Decimal)
         eq_(val, decimal.Decimal("45"))
@@ -3031,7 +3054,7 @@ class NumericRawSQLTest(fixtures.TestBase):
     @testing.provide_metadata
     def test_ints(self, connection):
         metadata = self.metadata
-        self._fixture(metadata, Integer, 45)
+        self._fixture(connection, metadata, Integer, 45)
         val = connection.exec_driver_sql("select val from t").scalar()
         assert isinstance(val, util.int_types)
         eq_(val, 45)
@@ -3039,7 +3062,7 @@ class NumericRawSQLTest(fixtures.TestBase):
     @testing.provide_metadata
     def test_float(self, connection):
         metadata = self.metadata
-        self._fixture(metadata, Float, 46.583)
+        self._fixture(connection, metadata, Float, 46.583)
         val = connection.exec_driver_sql("select val from t").scalar()
         assert isinstance(val, float)
 
@@ -3050,19 +3073,14 @@ class NumericRawSQLTest(fixtures.TestBase):
             eq_(val, 46.583)
 
 
-interval_table = metadata = None
-
-
-class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
+class IntervalTest(fixtures.TablesTest, AssertsExecutionResults):
 
     __backend__ = True
 
     @classmethod
-    def setup_class(cls):
-        global interval_table, metadata
-        metadata = MetaData(testing.db)
-        interval_table = Table(
-            "intervaltable",
+    def define_tables(cls, metadata):
+        Table(
+            "intervals",
             metadata,
             Column(
                 "id", Integer, primary_key=True, test_needs_autoincrement=True
@@ -3074,16 +3092,6 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             Column("non_native_interval", Interval(native=False)),
         )
-        metadata.create_all()
-
-    @engines.close_first
-    def teardown(self):
-        with testing.db.connect() as conn:
-            conn.execute(interval_table.delete())
-
-    @classmethod
-    def teardown_class(cls):
-        metadata.drop_all()
 
     def test_non_native_adapt(self):
         interval = Interval(native=False)
@@ -3092,30 +3100,32 @@ class IntervalTest(fixtures.TestBase, AssertsExecutionResults):
         assert adapted.native is False
         eq_(str(adapted), "DATETIME")
 
-    def test_roundtrip(self):
+    def test_roundtrip(self, connection):
+        interval_table = self.tables.intervals
+
         small_delta = datetime.timedelta(days=15, seconds=5874)
         delta = datetime.timedelta(14)
-        with testing.db.begin() as conn:
-            conn.execute(
-                interval_table.insert(),
-                native_interval=small_delta,
-                native_interval_args=delta,
-                non_native_interval=delta,
-            )
-            row = conn.execute(interval_table.select()).first()
+        connection.execute(
+            interval_table.insert(),
+            native_interval=small_delta,
+            native_interval_args=delta,
+            non_native_interval=delta,
+        )
+        row = connection.execute(interval_table.select()).first()
         eq_(row.native_interval, small_delta)
         eq_(row.native_interval_args, delta)
         eq_(row.non_native_interval, delta)
 
-    def test_null(self):
-        with testing.db.begin() as conn:
-            conn.execute(
-                interval_table.insert(),
-                id=1,
-                native_inverval=None,
-                non_native_interval=None,
-            )
-            row = conn.execute(interval_table.select()).first()
+    def test_null(self, connection):
+        interval_table = self.tables.intervals
+
+        connection.execute(
+            interval_table.insert(),
+            id=1,
+            native_interval=None,
+            non_native_interval=None,
+        )
+        row = connection.execute(interval_table.select()).first()
         eq_(row.native_interval, None)
         eq_(row.native_interval_args, None)
         eq_(row.non_native_interval, None)
@@ -3215,25 +3225,24 @@ class BooleanTest(
             )
 
     @testing.requires.non_native_boolean_unconstrained
-    def test_nonnative_processor_coerces_integer_to_boolean(self):
+    def test_nonnative_processor_coerces_integer_to_boolean(self, connection):
         boolean_table = self.tables.boolean_table
-        with testing.db.connect() as conn:
-            conn.exec_driver_sql(
-                "insert into boolean_table (id, unconstrained_value) "
-                "values (1, 5)"
-            )
+        connection.exec_driver_sql(
+            "insert into boolean_table (id, unconstrained_value) "
+            "values (1, 5)"
+        )
 
-            eq_(
-                conn.exec_driver_sql(
-                    "select unconstrained_value from boolean_table"
-                ).scalar(),
-                5,
-            )
+        eq_(
+            connection.exec_driver_sql(
+                "select unconstrained_value from boolean_table"
+            ).scalar(),
+            5,
+        )
 
-            eq_(
-                conn.scalar(select(boolean_table.c.unconstrained_value)),
-                True,
-            )
+        eq_(
+            connection.scalar(select(boolean_table.c.unconstrained_value)),
+            True,
+        )
 
     def test_bind_processor_coercion_native_true(self):
         proc = Boolean().bind_processor(
index ec96af207e7f1240f4600c888d5678c9464021d6..946a01651a554a6a1ded4b443f7f58db4e1cf633 100644 (file)
@@ -1263,10 +1263,10 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
     __backend__ = True
 
     @testing.requires.update_from
-    def test_exec_two_table(self):
+    def test_exec_two_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
-        testing.db.execute(
+        connection.execute(
             addresses.update()
             .values(email_address=users.c.name)
             .where(users.c.id == addresses.c.user_id)
@@ -1280,14 +1280,14 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (4, 8, "x", "ed"),
             (5, 9, "x", "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
     @testing.requires.update_from
-    def test_exec_two_table_plus_alias(self):
+    def test_exec_two_table_plus_alias(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
         a1 = addresses.alias()
-        testing.db.execute(
+        connection.execute(
             addresses.update()
             .values(email_address=users.c.name)
             .where(users.c.id == a1.c.user_id)
@@ -1302,15 +1302,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (4, 8, "x", "ed"),
             (5, 9, "x", "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
     @testing.requires.update_from
-    def test_exec_three_table(self):
+    def test_exec_three_table(self, connection):
         users = self.tables.users
         addresses = self.tables.addresses
         dingalings = self.tables.dingalings
 
-        testing.db.execute(
+        connection.execute(
             addresses.update()
             .values(email_address=users.c.name)
             .where(users.c.id == addresses.c.user_id)
@@ -1326,15 +1326,15 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (4, 8, "x", "ed@lala.com"),
             (5, 9, "x", "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
     @testing.only_on("mysql", "Multi table update")
-    def test_exec_multitable(self):
+    def test_exec_multitable(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
         values = {addresses.c.email_address: "updated", users.c.name: "ed2"}
 
-        testing.db.execute(
+        connection.execute(
             addresses.update()
             .values(values)
             .where(users.c.id == addresses.c.user_id)
@@ -1348,18 +1348,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (4, 8, "x", "updated"),
             (5, 9, "x", "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
         expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
-        self._assert_users(users, expected)
+        self._assert_users(connection, users, expected)
 
     @testing.only_on("mysql", "Multi table update")
-    def test_exec_join_multitable(self):
+    def test_exec_join_multitable(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
         values = {addresses.c.email_address: "updated", users.c.name: "ed2"}
 
-        testing.db.execute(
+        connection.execute(
             update(users.join(addresses))
             .values(values)
             .where(users.c.name == "ed")
@@ -1372,18 +1372,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (4, 8, "x", "updated"),
             (5, 9, "x", "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
         expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
-        self._assert_users(users, expected)
+        self._assert_users(connection, users, expected)
 
     @testing.only_on("mysql", "Multi table update")
-    def test_exec_multitable_same_name(self):
+    def test_exec_multitable_same_name(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
         values = {addresses.c.name: "ad_ed2", users.c.name: "ed2"}
 
-        testing.db.execute(
+        connection.execute(
             addresses.update()
             .values(values)
             .where(users.c.id == addresses.c.user_id)
@@ -1397,18 +1397,18 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (4, 8, "ad_ed2", "ed@lala.com"),
             (5, 9, "x", "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
         expected = [(7, "jack"), (8, "ed2"), (9, "fred"), (10, "chuck")]
-        self._assert_users(users, expected)
+        self._assert_users(connection, users, expected)
 
-    def _assert_addresses(self, addresses, expected):
+    def _assert_addresses(self, connection, addresses, expected):
         stmt = addresses.select().order_by(addresses.c.id)
-        eq_(testing.db.execute(stmt).fetchall(), expected)
+        eq_(connection.execute(stmt).fetchall(), expected)
 
-    def _assert_users(self, users, expected):
+    def _assert_users(self, connection, users, expected):
         stmt = users.select().order_by(users.c.id)
-        eq_(testing.db.execute(stmt).fetchall(), expected)
+        eq_(connection.execute(stmt).fetchall(), expected)
 
 
 class UpdateFromMultiTableUpdateDefaultsTest(
@@ -1472,12 +1472,12 @@ class UpdateFromMultiTableUpdateDefaultsTest(
         )
 
     @testing.only_on("mysql", "Multi table update")
-    def test_defaults_second_table(self):
+    def test_defaults_second_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
         values = {addresses.c.email_address: "updated", users.c.name: "ed2"}
 
-        ret = testing.db.execute(
+        ret = connection.execute(
             addresses.update()
             .values(values)
             .where(users.c.id == addresses.c.user_id)
@@ -1491,18 +1491,18 @@ class UpdateFromMultiTableUpdateDefaultsTest(
             (3, 8, "updated"),
             (4, 9, "fred@fred.com"),
         ]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
         expected = [(8, "ed2", "im the update"), (9, "fred", "value")]
-        self._assert_users(users, expected)
+        self._assert_users(connection, users, expected)
 
     @testing.only_on("mysql", "Multi table update")
-    def test_defaults_second_table_same_name(self):
+    def test_defaults_second_table_same_name(self, connection):
         users, foobar = self.tables.users, self.tables.foobar
 
         values = {foobar.c.data: foobar.c.data + "a", users.c.name: "ed2"}
 
-        ret = testing.db.execute(
+        ret = connection.execute(
             users.update()
             .values(values)
             .where(users.c.id == foobar.c.user_id)
@@ -1519,16 +1519,16 @@ class UpdateFromMultiTableUpdateDefaultsTest(
             (3, 8, "d2a", "im the other update"),
             (4, 9, "d3", None),
         ]
-        self._assert_foobar(foobar, expected)
+        self._assert_foobar(connection, foobar, expected)
 
         expected = [(8, "ed2", "im the update"), (9, "fred", "value")]
-        self._assert_users(users, expected)
+        self._assert_users(connection, users, expected)
 
     @testing.only_on("mysql", "Multi table update")
-    def test_no_defaults_second_table(self):
+    def test_no_defaults_second_table(self, connection):
         users, addresses = self.tables.users, self.tables.addresses
 
-        ret = testing.db.execute(
+        ret = connection.execute(
             addresses.update()
             .values({"email_address": users.c.name})
             .where(users.c.id == addresses.c.user_id)
@@ -1538,20 +1538,20 @@ class UpdateFromMultiTableUpdateDefaultsTest(
         eq_(ret.prefetch_cols(), [])
 
         expected = [(2, 8, "ed"), (3, 8, "ed"), (4, 9, "fred@fred.com")]
-        self._assert_addresses(addresses, expected)
+        self._assert_addresses(connection, addresses, expected)
 
         # users table not actually updated, so no onupdate
         expected = [(8, "ed", "value"), (9, "fred", "value")]
-        self._assert_users(users, expected)
+        self._assert_users(connection, users, expected)
 
-    def _assert_foobar(self, foobar, expected):
+    def _assert_foobar(self, connection, foobar, expected):
         stmt = foobar.select().order_by(foobar.c.id)
-        eq_(testing.db.execute(stmt).fetchall(), expected)
+        eq_(connection.execute(stmt).fetchall(), expected)
 
-    def _assert_addresses(self, addresses, expected):
+    def _assert_addresses(self, connection, addresses, expected):
         stmt = addresses.select().order_by(addresses.c.id)
-        eq_(testing.db.execute(stmt).fetchall(), expected)
+        eq_(connection.execute(stmt).fetchall(), expected)
 
-    def _assert_users(self, users, expected):
+    def _assert_users(self, connection, users, expected):
         stmt = users.select().order_by(users.c.id)
-        eq_(testing.db.execute(stmt).fetchall(), expected)
+        eq_(connection.execute(stmt).fetchall(), expected)
diff --git a/tox.ini b/tox.ini
index 6cfcf62efc25acd3287a7625d3a849d2d51580de..e1aef1a23d772d52c15434b27e8cef8cd7ba54c8 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -56,7 +56,6 @@ setenv=
     PYTHONPATH=
     PYTHONNOUSERSITE=1
     MEMUSAGE=--nomemory
-    SQLALCHEMY_WARN_20=true
     BASECOMMAND=python -m pytest --rootdir {toxinidir} --log-info=sqlalchemy.testing
 
     WORKERS={env:TOX_WORKERS:-n4  --max-worker-restart=5}