]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Clean up .execute in test/sql/test_returning.py
authorGord Thompson <gord@gordthompson.com>
Tue, 14 Apr 2020 11:34:35 +0000 (05:34 -0600)
committerGord Thompson <gord@gordthompson.com>
Tue, 14 Apr 2020 11:59:46 +0000 (05:59 -0600)
Change-Id: I390b0c9926345f9f4deec06b51d1a11a18a72ca9

test/sql/test_returning.py

index d81ad7186926221341af5df6cb11da21808b92da..5f655db6dd69f05241cdee623467272546b05e77 100644 (file)
@@ -53,16 +53,17 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
             Column("full", Boolean),
             Column("goofy", GoofyType(50)),
         )
-        table.create(checkfirst=True)
+        with testing.db.connect() as conn:
+            table.create(conn, checkfirst=True)
 
     def teardown(self):
-        table.drop()
+        with testing.db.connect() as conn:
+            table.drop(conn)
 
-    def test_column_targeting(self):
-        result = (
-            table.insert()
-            .returning(table.c.id, table.c.full)
-            .execute({"persons": 1, "full": False})
+    def test_column_targeting(self, connection):
+        result = connection.execute(
+            table.insert().returning(table.c.id, table.c.full),
+            {"persons": 1, "full": False},
         )
 
         row = result.first()._mapping
@@ -70,11 +71,10 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         assert row[table.c.full] == row["full"]
         assert row["full"] is False
 
-        result = (
+        result = connection.execute(
             table.insert()
             .values(persons=5, full=True, goofy="somegoofy")
             .returning(table.c.persons, table.c.full, table.c.goofy)
-            .execute()
         )
         row = result.first()._mapping
         assert row[table.c.persons] == row["persons"] == 5
@@ -84,12 +84,11 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         eq_(row["goofy"], "FOOsomegoofyBAR")
 
     @testing.fails_on("firebird", "fb can't handle returning x AS y")
-    def test_labeling(self):
-        result = (
+    def test_labeling(self, connection):
+        result = connection.execute(
             table.insert()
             .values(persons=6)
             .returning(table.c.persons.label("lala"))
-            .execute()
         )
         row = result.first()._mapping
         assert row["lala"] == 6
@@ -97,53 +96,48 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
     @testing.fails_on(
         "firebird", "fb/kintersbasdb can't handle the bind params"
     )
-    def test_anon_expressions(self):
-        result = (
+    def test_anon_expressions(self, connection):
+        result = connection.execute(
             table.insert()
             .values(goofy="someOTHERgoofy")
             .returning(func.lower(table.c.goofy, type_=GoofyType))
-            .execute()
         )
         row = result.first()
         eq_(row[0], "foosomeothergoofyBAR")
 
-        result = (
-            table.insert()
-            .values(persons=12)
-            .returning(table.c.persons + 18)
-            .execute()
+        result = connection.execute(
+            table.insert().values(persons=12).returning(table.c.persons + 18)
         )
         row = result.first()
         eq_(row[0], 30)
 
-    def test_update_returning(self):
-        table.insert().execute(
-            [{"persons": 5, "full": False}, {"persons": 3, "full": False}]
+    def test_update_returning(self, connection):
+        connection.execute(
+            table.insert(),
+            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
         )
 
-        result = (
-            table.update(table.c.persons > 4, dict(full=True))
-            .returning(table.c.id)
-            .execute()
+        result = connection.execute(
+            table.update(table.c.persons > 4, dict(full=True)).returning(
+                table.c.id
+            )
         )
         eq_(result.fetchall(), [(1,)])
 
-        result2 = (
-            select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        result2 = connection.execute(
+            select([table.c.id, table.c.full]).order_by(table.c.id)
         )
         eq_(result2.fetchall(), [(1, True), (2, False)])
 
-    def test_insert_returning(self):
-        result = (
-            table.insert()
-            .returning(table.c.id)
-            .execute({"persons": 1, "full": False})
+    def test_insert_returning(self, connection):
+        result = connection.execute(
+            table.insert().returning(table.c.id), {"persons": 1, "full": False}
         )
 
         eq_(result.fetchall(), [(1,)])
 
     @testing.requires.multivalues_inserts
-    def test_multirow_returning(self):
+    def test_multirow_returning(self, connection):
         ins = (
             table.insert()
             .returning(table.c.id, table.c.persons)
@@ -155,11 +149,11 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
                 ]
             )
         )
-        result = testing.db.execute(ins)
+        result = connection.execute(ins)
         eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
 
-    def test_no_ipk_on_returning(self):
-        result = testing.db.execute(
+    def test_no_ipk_on_returning(self, connection):
+        result = connection.execute(
             table.insert().returning(table.c.id), {"persons": 1, "full": False}
         )
         assert_raises_message(
@@ -183,18 +177,19 @@ class ReturningTest(fixtures.TestBase, AssertsExecutionResults):
         )
         eq_([dict(row._mapping) for row in result4], [{"persons": 10}])
 
-    def test_delete_returning(self):
-        table.insert().execute(
-            [{"persons": 5, "full": False}, {"persons": 3, "full": False}]
+    def test_delete_returning(self, connection):
+        connection.execute(
+            table.insert(),
+            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
         )
 
-        result = (
-            table.delete(table.c.persons > 4).returning(table.c.id).execute()
+        result = connection.execute(
+            table.delete(table.c.persons > 4).returning(table.c.id)
         )
         eq_(result.fetchall(), [(1,)])
 
-        result2 = (
-            select([table.c.id, table.c.full]).order_by(table.c.id).execute()
+        result2 = connection.execute(
+            select([table.c.id, table.c.full]).order_by(table.c.id)
         )
         eq_(result2.fetchall(), [(2, False)])
 
@@ -204,7 +199,7 @@ class CompositeStatementTest(fixtures.TestBase):
     __backend__ = True
 
     @testing.provide_metadata
-    def test_select_doesnt_pollute_result(self):
+    def test_select_doesnt_pollute_result(self, connection):
         class MyType(TypeDecorator):
             impl = Integer
 
@@ -215,18 +210,17 @@ class CompositeStatementTest(fixtures.TestBase):
 
         t2 = Table("t2", self.metadata, Column("x", Integer))
 
-        self.metadata.create_all(testing.db)
-        with testing.db.connect() as conn:
-            conn.execute(t1.insert().values(x=5))
+        self.metadata.create_all(connection)
+        connection.execute(t1.insert().values(x=5))
 
-            stmt = (
-                t2.insert()
-                .values(x=select([t1.c.x]).scalar_subquery())
-                .returning(t2.c.x)
-            )
+        stmt = (
+            t2.insert()
+            .values(x=select([t1.c.x]).scalar_subquery())
+            .returning(t2.c.x)
+        )
 
-            result = conn.execute(stmt)
-            eq_(result.scalar(), 5)
+        result = connection.execute(stmt)
+        eq_(result.scalar(), 5)
 
 
 class SequenceReturningTest(fixtures.TestBase):
@@ -243,15 +237,19 @@ class SequenceReturningTest(fixtures.TestBase):
             Column("id", Integer, seq, primary_key=True),
             Column("data", String(50)),
         )
-        table.create(checkfirst=True)
+        with testing.db.connect() as conn:
+            table.create(conn, checkfirst=True)
 
     def teardown(self):
-        table.drop()
+        with testing.db.connect() as conn:
+            table.drop(conn)
 
-    def test_insert(self):
-        r = table.insert().values(data="hi").returning(table.c.id).execute()
+    def test_insert(self, connection):
+        r = connection.execute(
+            table.insert().values(data="hi").returning(table.c.id)
+        )
         assert r.first() == (1,)
-        assert seq.execute() == 2
+        assert connection.execute(seq) == 2
 
 
 class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
@@ -277,21 +275,23 @@ class KeyReturningTest(fixtures.TestBase, AssertsExecutionResults):
             ),
             Column("data", String(20)),
         )
-        table.create(checkfirst=True)
+        with testing.db.connect() as conn:
+            table.create(conn, checkfirst=True)
 
     def teardown(self):
-        table.drop()
+        with testing.db.connect() as conn:
+            table.drop(conn)
 
     @testing.exclude("firebird", "<", (2, 0), "2.0+ feature")
     @testing.exclude("postgresql", "<", (8, 2), "8.2+ feature")
-    def test_insert(self):
-        result = (
-            table.insert().returning(table.c.foo_id).execute(data="somedata")
+    def test_insert(self, connection):
+        result = connection.execute(
+            table.insert().returning(table.c.foo_id), data="somedata"
         )
         row = result.first()._mapping
         assert row[table.c.foo_id] == row["id"] == 1
 
-        result = table.select().execute().first()._mapping
+        result = connection.execute(table.select()).first()._mapping
         assert row[table.c.foo_id] == row["id"] == 1
 
 
@@ -325,9 +325,9 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             Column("upddef", Integer, onupdate=IncDefault()),
         )
 
-    def test_chained_insert_pk(self):
+    def test_chained_insert_pk(self, connection):
         t1 = self.tables.t1
-        result = testing.db.execute(
+        result = connection.execute(
             t1.insert().values(upddef=1).return_defaults(t1.c.insdef)
         )
         eq_(
@@ -338,9 +338,9 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             [1, 0],
         )
 
-    def test_arg_insert_pk(self):
+    def test_arg_insert_pk(self, connection):
         t1 = self.tables.t1
-        result = testing.db.execute(
+        result = connection.execute(
             t1.insert(return_defaults=[t1.c.insdef]).values(upddef=1)
         )
         eq_(
@@ -351,32 +351,32 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             [1, 0],
         )
 
-    def test_chained_update_pk(self):
+    def test_chained_update_pk(self, connection):
         t1 = self.tables.t1
-        testing.db.execute(t1.insert().values(upddef=1))
-        result = testing.db.execute(
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
             t1.update().values(data="d1").return_defaults(t1.c.upddef)
         )
         eq_(
             [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
         )
 
-    def test_arg_update_pk(self):
+    def test_arg_update_pk(self, connection):
         t1 = self.tables.t1
-        testing.db.execute(t1.insert().values(upddef=1))
-        result = testing.db.execute(
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
             t1.update(return_defaults=[t1.c.upddef]).values(data="d1")
         )
         eq_(
             [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
         )
 
-    def test_insert_non_default(self):
+    def test_insert_non_default(self, connection):
         """test that a column not marked at all as a
         default works with this feature."""
 
         t1 = self.tables.t1
-        result = testing.db.execute(
+        result = connection.execute(
             t1.insert().values(upddef=1).return_defaults(t1.c.data)
         )
         eq_(
@@ -387,13 +387,13 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             [1, None],
         )
 
-    def test_update_non_default(self):
+    def test_update_non_default(self, connection):
         """test that a column not marked at all as a
         default works with this feature."""
 
         t1 = self.tables.t1
-        testing.db.execute(t1.insert().values(upddef=1))
-        result = testing.db.execute(
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
             t1.update().values(upddef=2).return_defaults(t1.c.data)
         )
         eq_(
@@ -401,9 +401,9 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             [None],
         )
 
-    def test_insert_non_default_plus_default(self):
+    def test_insert_non_default_plus_default(self, connection):
         t1 = self.tables.t1
-        result = testing.db.execute(
+        result = connection.execute(
             t1.insert()
             .values(upddef=1)
             .return_defaults(t1.c.data, t1.c.insdef)
@@ -413,10 +413,10 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             {"id": 1, "data": None, "insdef": 0},
         )
 
-    def test_update_non_default_plus_default(self):
+    def test_update_non_default_plus_default(self, connection):
         t1 = self.tables.t1
-        testing.db.execute(t1.insert().values(upddef=1))
-        result = testing.db.execute(
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
             t1.update()
             .values(insdef=2)
             .return_defaults(t1.c.data, t1.c.upddef)
@@ -426,9 +426,9 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             {"data": None, "upddef": 1},
         )
 
-    def test_insert_all(self):
+    def test_insert_all(self, connection):
         t1 = self.tables.t1
-        result = testing.db.execute(
+        result = connection.execute(
             t1.insert().values(upddef=1).return_defaults()
         )
         eq_(
@@ -436,10 +436,10 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             {"id": 1, "data": None, "insdef": 0},
         )
 
-    def test_update_all(self):
+    def test_update_all(self, connection):
         t1 = self.tables.t1
-        testing.db.execute(t1.insert().values(upddef=1))
-        result = testing.db.execute(
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
             t1.update().values(insdef=2).return_defaults()
         )
         eq_(dict(result.returned_defaults._mapping), {"upddef": 1})