]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Split up Insert/InsertMany/Delete/Update returning tests 7047/head
authorDaniel Black <daniel@mariadb.org>
Mon, 13 Sep 2021 15:13:20 +0000 (01:13 +1000)
committerDaniel Black <daniel@mariadb.org>
Sat, 25 Sep 2021 06:23:19 +0000 (16:23 +1000)
test/sql/test_returning.py

index e2612ebed9cead341e7956136d58ddd04b691481..95e4c2ac01320539e1e7e37677d5f6fbabaf6c27 100644 (file)
@@ -90,8 +90,8 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
-class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
-    __requires__ = ("returning",)
+class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
+    __requires__ = ("insert_returning",)
     __backend__ = True
 
     run_create_tables = "each"
@@ -181,26 +181,6 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
         row = result.first()
         eq_(row[0], 30)
 
-    def test_update_returning(self, connection):
-        table = self.tables.tables
-        connection.execute(
-            table.insert(),
-            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
-        )
-
-        result = connection.execute(
-            table.update()
-            .values(dict(full=True))
-            .where(table.c.persons > 4)
-            .returning(table.c.id)
-        )
-        eq_(result.fetchall(), [(1,)])
-
-        result2 = connection.execute(
-            select(table.c.id, table.c.full).order_by(table.c.id)
-        )
-        eq_(result2.fetchall(), [(1, True), (2, False)])
-
     @testing.fails_on(
         "mssql",
         "driver has unknown issue with string concatenation "
@@ -234,6 +214,93 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
         )
         eq_(result2.fetchall(), [(1, "FOOsomegoofyBAR")])
 
+    def test_no_ipk_on_returning(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
+            table.insert().returning(table.c.id), {"persons": 1, "full": False}
+        )
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't call inserted_primary_key when returning\(\) is used.",
+            getattr,
+            result,
+            "inserted_primary_key",
+        )
+
+    def test_insert_returning(self, connection):
+        table = self.tables.tables
+        result = connection.execute(
+            table.insert().returning(table.c.id), {"persons": 1, "full": False}
+        )
+
+        eq_(result.fetchall(), [(1,)])
+
+    @testing.requires.multivalues_inserts
+    def test_multirow_returning(self, connection):
+        table = self.tables.tables
+        ins = (
+            table.insert()
+            .returning(table.c.id, table.c.persons)
+            .values(
+                [
+                    {"persons": 1, "full": False},
+                    {"persons": 2, "full": True},
+                    {"persons": 3, "full": False},
+                ]
+            )
+        )
+        result = connection.execute(ins)
+        eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
+
+    @testing.fails_on_everything_except("postgresql", "mariadb>=10.5",
+                                        "firebird")
+    def test_literal_returning(self, connection):
+        if testing.against("mariadb"):
+            quote = "`"
+        else:
+            quote = '"'
+        if testing.against("postgresql"):
+            literal_true = "true"
+        else:
+            literal_true = "1"
+
+        result4 = connection.exec_driver_sql(
+            'insert into tables (id, persons, %sfull%s) '
+            "values (5, 10, %s) returning persons" % (quote, quote,
+                                                      literal_true)
+        )
+        eq_([dict(row._mapping) for row in result4], [{"persons": 10}])
+
+
+class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults):
+    __requires__ = ("returning",)
+    __backend__ = True
+
+    run_create_tables = "each"
+
+    define_tables = InsertReturningTest.define_tables
+
+    @testing.requires.returning
+    def test_update_returning(self, connection):
+        table = self.tables.tables
+        connection.execute(
+            table.insert(),
+            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
+        )
+
+        result = connection.execute(
+            table.update()
+            .values(dict(full=True))
+            .where(table.c.persons > 4)
+            .returning(table.c.id)
+        )
+        eq_(result.fetchall(), [(1,)])
+
+        result2 = connection.execute(
+            select(table.c.id, table.c.full).order_by(table.c.id)
+        )
+        eq_(result2.fetchall(), [(1, True), (2, False)])
+
     def test_update_returning_w_expression_one(self, connection):
         table = self.tables.tables
         connection.execute(
@@ -299,69 +366,14 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
         )
         eq_(result.fetchall(), [(1, True), (2, True)])
 
-    @testing.requires.delete_returning
-    def test_delete_returning(self, connection):
-        table = self.tables.tables
-        connection.execute(
-            table.insert(),
-            [{"persons": 5, "full": False}, {"persons": 3, "full": False}],
-        )
-
-        result = connection.execute(
-            table.delete().returning(table.c.id, table.c.full)
-        )
-        eq_(result.fetchall(), [(1, False), (2, False)])
-
-    def test_insert_returning(self, connection):
-        table = self.tables.tables
-        result = connection.execute(
-            table.insert().returning(table.c.id), {"persons": 1, "full": False}
-        )
-
-        eq_(result.fetchall(), [(1,)])
-
-    @testing.requires.multivalues_inserts
-    def test_multirow_returning(self, connection):
-        table = self.tables.tables
-        ins = (
-            table.insert()
-            .returning(table.c.id, table.c.persons)
-            .values(
-                [
-                    {"persons": 1, "full": False},
-                    {"persons": 2, "full": True},
-                    {"persons": 3, "full": False},
-                ]
-            )
-        )
-        result = connection.execute(ins)
-        eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)])
 
-    def test_no_ipk_on_returning(self, connection):
-        table = self.tables.tables
-        result = connection.execute(
-            table.insert().returning(table.c.id), {"persons": 1, "full": False}
-        )
-        assert_raises_message(
-            sa_exc.InvalidRequestError,
-            r"Can't call inserted_primary_key when returning\(\) is used.",
-            getattr,
-            result,
-            "inserted_primary_key",
-        )
+class DeleteReturningTest(fixtures.TablesTest, AssertsExecutionResults):
+    __requires__ = ("delete_returning",)
+    __backend__ = True
 
-    @testing.fails_on_everything_except("postgresql", "firebird")
-    def test_literal_returning(self, connection):
-        if testing.against("postgresql"):
-            literal_true = "true"
-        else:
-            literal_true = "1"
+    run_create_tables = "each"
 
-        result4 = connection.exec_driver_sql(
-            'insert into tables (id, persons, "full") '
-            "values (5, 10, %s) returning persons" % literal_true
-        )
-        eq_([dict(row._mapping) for row in result4], [{"persons": 10}])
+    define_tables = InsertReturningTest.define_tables
 
     def test_delete_returning(self, connection):
         table = self.tables.tables
@@ -382,7 +394,7 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults):
 
 
 class CompositeStatementTest(fixtures.TestBase):
-    __requires__ = ("returning",)
+    __requires__ = ("insert_returning",)
     __backend__ = True
 
     @testing.provide_metadata
@@ -412,7 +424,7 @@ class CompositeStatementTest(fixtures.TestBase):
 
 
 class SequenceReturningTest(fixtures.TablesTest):
-    __requires__ = "returning", "sequences"
+    __requires__ = "insert_returning", "sequences"
     __backend__ = True
 
     @classmethod
@@ -447,7 +459,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults):
 
     """test returning() works with columns that define 'key'."""
 
-    __requires__ = ("returning",)
+    __requires__ = ("insert_returning",)
     __backend__ = True
 
     @classmethod
@@ -479,7 +491,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults):
         assert row[table.c.foo_id] == row["id"] == 1
 
 
-class ReturnDefaultsTest(fixtures.TablesTest):
+class InsertReturnDefaultsTest(fixtures.TablesTest):
     __requires__ = ("returning",)
     run_define_tables = "each"
     __backend__ = True
@@ -535,67 +547,99 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             [1, 0],
         )
 
-    def test_chained_update_pk(self, connection):
+    def test_insert_non_default(self, connection):
+        """test that a column not marked at all as a
+        default works with this feature."""
+
         t1 = self.tables.t1
-        connection.execute(t1.insert().values(upddef=1))
         result = connection.execute(
-            t1.update().values(data="d1").return_defaults(t1.c.upddef)
+            t1.insert().values(upddef=1).return_defaults(t1.c.data)
         )
         eq_(
-            [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
+            [
+                result.returned_defaults._mapping[k]
+                for k in (t1.c.id, t1.c.data)
+            ],
+            [1, None],
         )
 
-    def test_arg_update_pk(self, connection):
+    def test_insert_sql_expr(self, connection):
+        from sqlalchemy import literal
+
         t1 = self.tables.t1
-        connection.execute(t1.insert().values(upddef=1))
         result = connection.execute(
-            t1.update().return_defaults(t1.c.upddef).values(data="d1")
+            t1.insert().return_defaults().values(insdef=literal(10) + 5)
         )
+
         eq_(
-            [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
+            result.returned_defaults._mapping,
+            {"id": 1, "data": None, "insdef": 15, "upddef": None},
         )
 
-    def test_insert_non_default(self, connection):
-        """test that a column not marked at all as a
-        default works with this feature."""
+    def test_insert_non_default_plus_default(self, connection):
+        t1 = self.tables.t1
+        result = connection.execute(
+            t1.insert()
+            .values(upddef=1)
+            .return_defaults(t1.c.data, t1.c.insdef)
+        )
+        eq_(
+            dict(result.returned_defaults._mapping),
+            {"id": 1, "data": None, "insdef": 0},
+        )
+        eq_(result.inserted_primary_key, (1,))
 
+    def test_insert_all(self, connection):
         t1 = self.tables.t1
         result = connection.execute(
-            t1.insert().values(upddef=1).return_defaults(t1.c.data)
+            t1.insert().values(upddef=1).return_defaults()
         )
         eq_(
-            [
-                result.returned_defaults._mapping[k]
-                for k in (t1.c.id, t1.c.data)
-            ],
-            [1, None],
+            dict(result.returned_defaults._mapping),
+            {"id": 1, "data": None, "insdef": 0},
         )
+        eq_(result.inserted_primary_key, (1,))
 
-    def test_update_non_default(self, connection):
-        """test that a column not marked at all as a
-        default works with this feature."""
 
+class UpdatedReturnDefaultsTest(fixtures.TablesTest):
+    __requires__ = ("returning",)
+    run_define_tables = "each"
+    __backend__ = True
+
+    define_tables = InsertReturnDefaultsTest.define_tables
+
+    def test_chained_update_pk(self, connection):
         t1 = self.tables.t1
         connection.execute(t1.insert().values(upddef=1))
         result = connection.execute(
-            t1.update().values(upddef=2).return_defaults(t1.c.data)
+            t1.update().values(data="d1").return_defaults(t1.c.upddef)
         )
         eq_(
-            [result.returned_defaults._mapping[k] for k in (t1.c.data,)],
-            [None],
+            [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
         )
 
-    def test_insert_sql_expr(self, connection):
-        from sqlalchemy import literal
-
+    def test_arg_update_pk(self, connection):
         t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
         result = connection.execute(
-            t1.insert().return_defaults().values(insdef=literal(10) + 5)
+            t1.update().return_defaults(t1.c.upddef).values(data="d1")
+        )
+        eq_(
+            [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
         )
 
+    def test_update_non_default(self, connection):
+        """test that a column not marked at all as a
+        default works with this feature."""
+
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
+            t1.update().values(upddef=2).return_defaults(t1.c.data)
+        )
         eq_(
-            result.returned_defaults._mapping,
-            {"id": 1, "data": None, "insdef": 15, "upddef": None},
+            [result.returned_defaults._mapping[k] for k in (t1.c.data,)],
+            [None],
         )
 
     def test_update_sql_expr(self, connection):
@@ -609,19 +653,6 @@ class ReturnDefaultsTest(fixtures.TablesTest):
 
         eq_(result.returned_defaults._mapping, {"upddef": 15})
 
-    def test_insert_non_default_plus_default(self, connection):
-        t1 = self.tables.t1
-        result = connection.execute(
-            t1.insert()
-            .values(upddef=1)
-            .return_defaults(t1.c.data, t1.c.insdef)
-        )
-        eq_(
-            dict(result.returned_defaults._mapping),
-            {"id": 1, "data": None, "insdef": 0},
-        )
-        eq_(result.inserted_primary_key, (1,))
-
     def test_update_non_default_plus_default(self, connection):
         t1 = self.tables.t1
         connection.execute(t1.insert().values(upddef=1))
@@ -635,17 +666,6 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             {"data": None, "upddef": 1},
         )
 
-    def test_insert_all(self, connection):
-        t1 = self.tables.t1
-        result = connection.execute(
-            t1.insert().values(upddef=1).return_defaults()
-        )
-        eq_(
-            dict(result.returned_defaults._mapping),
-            {"id": 1, "data": None, "insdef": 0},
-        )
-        eq_(result.inserted_primary_key, (1,))
-
     def test_update_all(self, connection):
         t1 = self.tables.t1
         connection.execute(t1.insert().values(upddef=1))
@@ -654,7 +674,14 @@ class ReturnDefaultsTest(fixtures.TablesTest):
         )
         eq_(dict(result.returned_defaults._mapping), {"upddef": 1})
 
-    @testing.requires.insert_executemany_returning
+
+class InsertManyReturnDefaultsTest(fixtures.TablesTest):
+    __requires__ = ("insert_executemany_returning",)
+    run_define_tables = "each"
+    __backend__ = True
+
+    define_tables = InsertReturnDefaultsTest.define_tables
+
     def test_insert_executemany_no_defaults_passed(self, connection):
         t1 = self.tables.t1
         result = connection.execute(
@@ -698,7 +725,6 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             lambda: result.inserted_primary_key,
         )
 
-    @testing.requires.insert_executemany_returning
     def test_insert_executemany_insdefault_passed(self, connection):
         t1 = self.tables.t1
         result = connection.execute(
@@ -742,7 +768,6 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             lambda: result.inserted_primary_key,
         )
 
-    @testing.requires.insert_executemany_returning
     def test_insert_executemany_only_pk_passed(self, connection):
         t1 = self.tables.t1
         result = connection.execute(