]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Clean up .execute calls in remaining suite tests
authorGord Thompson <gord@gordthompson.com>
Sun, 12 Apr 2020 13:03:25 +0000 (07:03 -0600)
committerGord Thompson <gord@gordthompson.com>
Sun, 12 Apr 2020 13:03:25 +0000 (07:03 -0600)
Change-Id: Ib5c7f46067bcf5b162060476cc323bf671db101a

lib/sqlalchemy/testing/suite/test_cte.py
lib/sqlalchemy/testing/suite/test_ddl.py
lib/sqlalchemy/testing/suite/test_deprecations.py
lib/sqlalchemy/testing/suite/test_types.py
lib/sqlalchemy/testing/suite/test_update_delete.py

index c7e6a266ca4384b26b1b2db4a5624f79750d7c65..fab457606732a1171820ab7948d7fc654893c696 100644 (file)
@@ -37,16 +37,17 @@ class CTETest(fixtures.TablesTest):
 
     @classmethod
     def insert_data(cls):
-        config.db.execute(
-            cls.tables.some_table.insert(),
-            [
-                {"id": 1, "data": "d1", "parent_id": None},
-                {"id": 2, "data": "d2", "parent_id": 1},
-                {"id": 3, "data": "d3", "parent_id": 1},
-                {"id": 4, "data": "d4", "parent_id": 3},
-                {"id": 5, "data": "d5", "parent_id": 3},
-            ],
-        )
+        with config.db.connect() as conn:
+            conn.execute(
+                cls.tables.some_table.insert(),
+                [
+                    {"id": 1, "data": "d1", "parent_id": None},
+                    {"id": 2, "data": "d2", "parent_id": 1},
+                    {"id": 3, "data": "d3", "parent_id": 1},
+                    {"id": 4, "data": "d4", "parent_id": 3},
+                    {"id": 5, "data": "d5", "parent_id": 3},
+                ],
+            )
 
     def test_select_nonrecursive_round_trip(self):
         some_table = self.tables.some_table
index 81a55e18ae47617111b575985b434fe6a9edb707..1f49106fb6b22f7a574dd8406011313785eebe6b 100644 (file)
@@ -67,25 +67,27 @@ class TableDDLTest(fixtures.TestBase):
 
     @requirements.comment_reflection
     @util.provide_metadata
-    def test_add_table_comment(self):
+    def test_add_table_comment(self, connection):
         table = self._simple_fixture()
-        table.create(config.db, checkfirst=False)
+        table.create(connection, checkfirst=False)
         table.comment = "a comment"
-        config.db.execute(schema.SetTableComment(table))
+        connection.execute(schema.SetTableComment(table))
         eq_(
-            inspect(config.db).get_table_comment("test_table"),
+            inspect(connection).get_table_comment("test_table"),
             {"text": "a comment"},
         )
 
     @requirements.comment_reflection
     @util.provide_metadata
-    def test_drop_table_comment(self):
+    def test_drop_table_comment(self, connection):
         table = self._simple_fixture()
-        table.create(config.db, checkfirst=False)
+        table.create(connection, checkfirst=False)
         table.comment = "a comment"
-        config.db.execute(schema.SetTableComment(table))
-        config.db.execute(schema.DropTableComment(table))
-        eq_(inspect(config.db).get_table_comment("test_table"), {"text": None})
+        connection.execute(schema.SetTableComment(table))
+        connection.execute(schema.DropTableComment(table))
+        eq_(
+            inspect(connection).get_table_comment("test_table"), {"text": None}
+        )
 
 
 __all__ = ("TableDDLTest",)
index d0202a0a9557552991517953fd330fcb75633505..126d82fe975888db7c88d17163797c43d7a0840d 100644 (file)
@@ -24,20 +24,21 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
 
     @classmethod
     def insert_data(cls):
-        config.db.execute(
-            cls.tables.some_table.insert(),
-            [
-                {"id": 1, "x": 1, "y": 2},
-                {"id": 2, "x": 2, "y": 3},
-                {"id": 3, "x": 3, "y": 4},
-                {"id": 4, "x": 4, "y": 5},
-            ],
-        )
-
-    def _assert_result(self, select, result, params=()):
-        eq_(config.db.execute(select, params).fetchall(), result)
-
-    def test_plain_union(self):
+        with config.db.connect() as conn:
+            conn.execute(
+                cls.tables.some_table.insert(),
+                [
+                    {"id": 1, "x": 1, "y": 2},
+                    {"id": 2, "x": 2, "y": 3},
+                    {"id": 3, "x": 3, "y": 4},
+                    {"id": 4, "x": 4, "y": 5},
+                ],
+            )
+
+    def _assert_result(self, conn, select, result, params=()):
+        eq_(conn.execute(select, params).fetchall(), result)
+
+    def test_plain_union(self, connection):
         table = self.tables.some_table
         s1 = select([table]).where(table.c.id == 2)
         s2 = select([table]).where(table.c.id == 3)
@@ -47,7 +48,9 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
             "The SelectBase.c and SelectBase.columns "
             "attributes are deprecated"
         ):
-            self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
+            self._assert_result(
+                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+            )
 
     # note we've had to remove one use case entirely, which is this
     # one.   the Select gets its FROMS from the WHERE clause and the
@@ -56,7 +59,7 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
     # ORDER BY without adding the SELECT into the FROM and breaking the
     # query.  Users will have to adjust for this use case if they were doing
     # it before.
-    def _dont_test_select_from_plain_union(self):
+    def _dont_test_select_from_plain_union(self, connection):
         table = self.tables.some_table
         s1 = select([table]).where(table.c.id == 2)
         s2 = select([table]).where(table.c.id == 3)
@@ -66,11 +69,13 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
             "The SelectBase.c and SelectBase.columns "
             "attributes are deprecated"
         ):
-            self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
+            self._assert_result(
+                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+            )
 
     @testing.requires.order_by_col_from_union
     @testing.requires.parens_in_union_contained_select_w_limit_offset
-    def test_limit_offset_selectable_in_unions(self):
+    def test_limit_offset_selectable_in_unions(self, connection):
         table = self.tables.some_table
         s1 = (
             select([table])
@@ -90,10 +95,12 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
             "The SelectBase.c and SelectBase.columns "
             "attributes are deprecated"
         ):
-            self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
+            self._assert_result(
+                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+            )
 
     @testing.requires.parens_in_union_contained_select_wo_limit_offset
-    def test_order_by_selectable_in_unions(self):
+    def test_order_by_selectable_in_unions(self, connection):
         table = self.tables.some_table
         s1 = select([table]).where(table.c.id == 2).order_by(table.c.id)
         s2 = select([table]).where(table.c.id == 3).order_by(table.c.id)
@@ -103,9 +110,11 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
             "The SelectBase.c and SelectBase.columns "
             "attributes are deprecated"
         ):
-            self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
+            self._assert_result(
+                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+            )
 
-    def test_distinct_selectable_in_unions(self):
+    def test_distinct_selectable_in_unions(self, connection):
         table = self.tables.some_table
         s1 = select([table]).where(table.c.id == 2).distinct()
         s2 = select([table]).where(table.c.id == 3).distinct()
@@ -115,9 +124,11 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
             "The SelectBase.c and SelectBase.columns "
             "attributes are deprecated"
         ):
-            self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
+            self._assert_result(
+                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+            )
 
-    def test_limit_offset_aliased_selectable_in_unions(self):
+    def test_limit_offset_aliased_selectable_in_unions(self, connection):
         table = self.tables.some_table
         s1 = (
             select([table])
@@ -141,4 +152,6 @@ class DeprecatedCompoundSelectTest(fixtures.TablesTest):
             "The SelectBase.c and SelectBase.columns "
             "attributes are deprecated"
         ):
-            self._assert_result(u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)])
+            self._assert_result(
+                connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+            )
index 9dabdbd650ab58f40d5f6d6fd77b4ed26ca01081..7719a3b3c35a7292f9a4b7fdfe15dc3970679ddb 100644 (file)
@@ -519,10 +519,10 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
             filter_=lambda n: n is not None and round(n, 5) or None,
         )
 
-    def test_float_coerce_round_trip(self):
+    def test_float_coerce_round_trip(self, connection):
         expr = 15.7563
 
-        val = testing.db.scalar(select([literal(expr)]))
+        val = connection.scalar(select([literal(expr)]))
         eq_(val, expr)
 
     # this does not work in MySQL, see #4036, however we choose not
@@ -530,17 +530,17 @@ class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
 
     @testing.requires.implicit_decimal_binds
     @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
-    def test_decimal_coerce_round_trip(self):
+    def test_decimal_coerce_round_trip(self, connection):
         expr = decimal.Decimal("15.7563")
 
-        val = testing.db.scalar(select([literal(expr)]))
+        val = connection.scalar(select([literal(expr)]))
         eq_(val, expr)
 
     @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
-    def test_decimal_coerce_round_trip_w_cast(self):
+    def test_decimal_coerce_round_trip_w_cast(self, connection):
         expr = decimal.Decimal("15.7563")
 
-        val = testing.db.scalar(select([cast(expr, Numeric(10, 4))]))
+        val = connection.scalar(select([cast(expr, Numeric(10, 4))]))
         eq_(val, expr)
 
     @testing.requires.precision_numerics_general
index 97bdf0ad76d2a32a62d09845427361db623417da..6003a09941eaf7717be54dff2abfbfea31ef02ad 100644 (file)
@@ -22,33 +22,34 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest):
 
     @classmethod
     def insert_data(cls):
-        config.db.execute(
-            cls.tables.plain_pk.insert(),
-            [
-                {"id": 1, "data": "d1"},
-                {"id": 2, "data": "d2"},
-                {"id": 3, "data": "d3"},
-            ],
-        )
-
-    def test_update(self):
+        with config.db.connect() as conn:
+            conn.execute(
+                cls.tables.plain_pk.insert(),
+                [
+                    {"id": 1, "data": "d1"},
+                    {"id": 2, "data": "d2"},
+                    {"id": 3, "data": "d3"},
+                ],
+            )
+
+    def test_update(self, connection):
         t = self.tables.plain_pk
-        r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new")
+        r = connection.execute(t.update().where(t.c.id == 2), data="d2_new")
         assert not r.is_insert
         assert not r.returns_rows
 
         eq_(
-            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
+            connection.execute(t.select().order_by(t.c.id)).fetchall(),
             [(1, "d1"), (2, "d2_new"), (3, "d3")],
         )
 
-    def test_delete(self):
+    def test_delete(self, connection):
         t = self.tables.plain_pk
-        r = config.db.execute(t.delete().where(t.c.id == 2))
+        r = connection.execute(t.delete().where(t.c.id == 2))
         assert not r.is_insert
         assert not r.returns_rows
         eq_(
-            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
+            connection.execute(t.select().order_by(t.c.id)).fetchall(),
             [(1, "d1"), (3, "d3")],
         )