From d25d5ea3abe094f282c53c7dd87f5f53a9e85248 Mon Sep 17 00:00:00 2001 From: Daniel Black Date: Tue, 14 Sep 2021 01:13:20 +1000 Subject: [PATCH] Split up Insert/InsertMany/Delete/Update returning tests --- test/sql/test_returning.py | 309 ++++++++++++++++++++----------------- 1 file changed, 167 insertions(+), 142 deletions(-) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index e2612ebed9..95e4c2ac01 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -90,8 +90,8 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL): ) -class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): - __requires__ = ("returning",) +class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults): + __requires__ = ("insert_returning",) __backend__ = True run_create_tables = "each" @@ -181,26 +181,6 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): row = result.first() eq_(row[0], 30) - def test_update_returning(self, connection): - table = self.tables.tables - connection.execute( - table.insert(), - [{"persons": 5, "full": False}, {"persons": 3, "full": False}], - ) - - result = connection.execute( - table.update() - .values(dict(full=True)) - .where(table.c.persons > 4) - .returning(table.c.id) - ) - eq_(result.fetchall(), [(1,)]) - - result2 = connection.execute( - select(table.c.id, table.c.full).order_by(table.c.id) - ) - eq_(result2.fetchall(), [(1, True), (2, False)]) - @testing.fails_on( "mssql", "driver has unknown issue with string concatenation " @@ -234,6 +214,93 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) eq_(result2.fetchall(), [(1, "FOOsomegoofyBAR")]) + def test_no_ipk_on_returning(self, connection): + table = self.tables.tables + result = connection.execute( + table.insert().returning(table.c.id), {"persons": 1, "full": False} + ) + assert_raises_message( + sa_exc.InvalidRequestError, + r"Can't call inserted_primary_key when returning\(\) is used.", + getattr, + result, + "inserted_primary_key", + ) + + def test_insert_returning(self, connection): + table = self.tables.tables + result = connection.execute( + table.insert().returning(table.c.id), {"persons": 1, "full": False} + ) + + eq_(result.fetchall(), [(1,)]) + + @testing.requires.multivalues_inserts + def test_multirow_returning(self, connection): + table = self.tables.tables + ins = ( + table.insert() + .returning(table.c.id, table.c.persons) + .values( + [ + {"persons": 1, "full": False}, + {"persons": 2, "full": True}, + {"persons": 3, "full": False}, + ] + ) + ) + result = connection.execute(ins) + eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) + + @testing.fails_on_everything_except("postgresql", "mariadb>=10.5", + "firebird") + def test_literal_returning(self, connection): + if testing.against("mariadb"): + quote = "`" + else: + quote = '"' + if testing.against("postgresql"): + literal_true = "true" + else: + literal_true = "1" + + result4 = connection.exec_driver_sql( + 'insert into tables (id, persons, %sfull%s) ' + "values (5, 10, %s) returning persons" % (quote, quote, + literal_true) + ) + eq_([dict(row._mapping) for row in result4], [{"persons": 10}]) + + +class UpdateReturningTest(fixtures.TablesTest, AssertsExecutionResults): + __requires__ = ("returning",) + __backend__ = True + + run_create_tables = "each" + + define_tables = InsertReturningTest.define_tables + + @testing.requires.returning + def test_update_returning(self, connection): + table = self.tables.tables + connection.execute( + table.insert(), + [{"persons": 5, "full": False}, {"persons": 3, "full": False}], + ) + + result = connection.execute( + table.update() + .values(dict(full=True)) + .where(table.c.persons > 4) + .returning(table.c.id) + ) + eq_(result.fetchall(), [(1,)]) + + result2 = connection.execute( + select(table.c.id, table.c.full).order_by(table.c.id) + ) + eq_(result2.fetchall(), [(1, True), (2, False)]) + def test_update_returning_w_expression_one(self, connection): table = self.tables.tables connection.execute( @@ -299,69 +366,14 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): ) eq_(result.fetchall(), [(1, True), (2, True)]) - @testing.requires.delete_returning - def test_delete_returning(self, connection): - table = self.tables.tables - connection.execute( - table.insert(), - [{"persons": 5, "full": False}, {"persons": 3, "full": False}], - ) - - result = connection.execute( - table.delete().returning(table.c.id, table.c.full) - ) - eq_(result.fetchall(), [(1, False), (2, False)]) - - def test_insert_returning(self, connection): - table = self.tables.tables - result = connection.execute( - table.insert().returning(table.c.id), {"persons": 1, "full": False} - ) - - eq_(result.fetchall(), [(1,)]) - - @testing.requires.multivalues_inserts - def test_multirow_returning(self, connection): - table = self.tables.tables - ins = ( - table.insert() - .returning(table.c.id, table.c.persons) - .values( - [ - {"persons": 1, "full": False}, - {"persons": 2, "full": True}, - {"persons": 3, "full": False}, - ] - ) - ) - result = connection.execute(ins) - eq_(result.fetchall(), [(1, 1), (2, 2), (3, 3)]) - def test_no_ipk_on_returning(self, connection): - table = self.tables.tables - result = connection.execute( - table.insert().returning(table.c.id), {"persons": 1, "full": False} - ) - assert_raises_message( - sa_exc.InvalidRequestError, - r"Can't call inserted_primary_key when returning\(\) is used.", - getattr, - result, - "inserted_primary_key", - ) +class DeleteReturningTest(fixtures.TablesTest, AssertsExecutionResults): + __requires__ = ("delete_returning",) + __backend__ = True - @testing.fails_on_everything_except("postgresql", "firebird") - def test_literal_returning(self, connection): - if testing.against("postgresql"): - literal_true = "true" - else: - literal_true = "1" + run_create_tables = "each" - result4 = connection.exec_driver_sql( - 'insert into tables (id, persons, "full") ' - "values (5, 10, %s) returning persons" % literal_true - ) - eq_([dict(row._mapping) for row in result4], [{"persons": 10}]) + define_tables = InsertReturningTest.define_tables def test_delete_returning(self, connection): table = self.tables.tables @@ -382,7 +394,7 @@ class ReturningTest(fixtures.TablesTest, AssertsExecutionResults): class CompositeStatementTest(fixtures.TestBase): - __requires__ = ("returning",) + __requires__ = ("insert_returning",) __backend__ = True @testing.provide_metadata @@ -412,7 +424,7 @@ class CompositeStatementTest(fixtures.TestBase): class SequenceReturningTest(fixtures.TablesTest): - __requires__ = "returning", "sequences" + __requires__ = "insert_returning", "sequences" __backend__ = True @classmethod @@ -447,7 +459,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): """test returning() works with columns that define 'key'.""" - __requires__ = ("returning",) + __requires__ = ("insert_returning",) __backend__ = True @classmethod @@ -479,7 +491,7 @@ class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): assert row[table.c.foo_id] == row["id"] == 1 -class ReturnDefaultsTest(fixtures.TablesTest): +class InsertReturnDefaultsTest(fixtures.TablesTest): __requires__ = ("returning",) run_define_tables = "each" __backend__ = True @@ -535,67 +547,99 @@ class ReturnDefaultsTest(fixtures.TablesTest): [1, 0], ) - def test_chained_update_pk(self, connection): + def test_insert_non_default(self, connection): + """test that a column not marked at all as a + default works with this feature.""" + t1 = self.tables.t1 - connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.update().values(data="d1").return_defaults(t1.c.upddef) + t1.insert().values(upddef=1).return_defaults(t1.c.data) ) eq_( - [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + [ + result.returned_defaults._mapping[k] + for k in (t1.c.id, t1.c.data) + ], + [1, None], ) - def test_arg_update_pk(self, connection): + def test_insert_sql_expr(self, connection): + from sqlalchemy import literal + t1 = self.tables.t1 - connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.update().return_defaults(t1.c.upddef).values(data="d1") + t1.insert().return_defaults().values(insdef=literal(10) + 5) ) + eq_( - [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] + result.returned_defaults._mapping, + {"id": 1, "data": None, "insdef": 15, "upddef": None}, ) - def test_insert_non_default(self, connection): - """test that a column not marked at all as a - default works with this feature.""" + def test_insert_non_default_plus_default(self, connection): + t1 = self.tables.t1 + result = connection.execute( + t1.insert() + .values(upddef=1) + .return_defaults(t1.c.data, t1.c.insdef) + ) + eq_( + dict(result.returned_defaults._mapping), + {"id": 1, "data": None, "insdef": 0}, + ) + eq_(result.inserted_primary_key, (1,)) + def test_insert_all(self, connection): t1 = self.tables.t1 result = connection.execute( - t1.insert().values(upddef=1).return_defaults(t1.c.data) + t1.insert().values(upddef=1).return_defaults() ) eq_( - [ - result.returned_defaults._mapping[k] - for k in (t1.c.id, t1.c.data) - ], - [1, None], + dict(result.returned_defaults._mapping), + {"id": 1, "data": None, "insdef": 0}, ) + eq_(result.inserted_primary_key, (1,)) - def test_update_non_default(self, connection): - """test that a column not marked at all as a - default works with this feature.""" +class UpdatedReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("returning",) + run_define_tables = "each" + __backend__ = True + + define_tables = InsertReturnDefaultsTest.define_tables + + def test_chained_update_pk(self, connection): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.update().values(upddef=2).return_defaults(t1.c.data) + t1.update().values(data="d1").return_defaults(t1.c.upddef) ) eq_( - [result.returned_defaults._mapping[k] for k in (t1.c.data,)], - [None], + [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] ) - def test_insert_sql_expr(self, connection): - from sqlalchemy import literal - + def test_arg_update_pk(self, connection): t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) result = connection.execute( - t1.insert().return_defaults().values(insdef=literal(10) + 5) + t1.update().return_defaults(t1.c.upddef).values(data="d1") + ) + eq_( + [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1] ) + def test_update_non_default(self, connection): + """test that a column not marked at all as a + default works with this feature.""" + + t1 = self.tables.t1 + connection.execute(t1.insert().values(upddef=1)) + result = connection.execute( + t1.update().values(upddef=2).return_defaults(t1.c.data) + ) eq_( - result.returned_defaults._mapping, - {"id": 1, "data": None, "insdef": 15, "upddef": None}, + [result.returned_defaults._mapping[k] for k in (t1.c.data,)], + [None], ) def test_update_sql_expr(self, connection): @@ -609,19 +653,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): eq_(result.returned_defaults._mapping, {"upddef": 15}) - def test_insert_non_default_plus_default(self, connection): - t1 = self.tables.t1 - result = connection.execute( - t1.insert() - .values(upddef=1) - .return_defaults(t1.c.data, t1.c.insdef) - ) - eq_( - dict(result.returned_defaults._mapping), - {"id": 1, "data": None, "insdef": 0}, - ) - eq_(result.inserted_primary_key, (1,)) - def test_update_non_default_plus_default(self, connection): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) @@ -635,17 +666,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): {"data": None, "upddef": 1}, ) - def test_insert_all(self, connection): - t1 = self.tables.t1 - result = connection.execute( - t1.insert().values(upddef=1).return_defaults() - ) - eq_( - dict(result.returned_defaults._mapping), - {"id": 1, "data": None, "insdef": 0}, - ) - eq_(result.inserted_primary_key, (1,)) - def test_update_all(self, connection): t1 = self.tables.t1 connection.execute(t1.insert().values(upddef=1)) @@ -654,7 +674,14 @@ class ReturnDefaultsTest(fixtures.TablesTest): ) eq_(dict(result.returned_defaults._mapping), {"upddef": 1}) - @testing.requires.insert_executemany_returning + +class InsertManyReturnDefaultsTest(fixtures.TablesTest): + __requires__ = ("insert_executemany_returning",) + run_define_tables = "each" + __backend__ = True + + define_tables = InsertReturnDefaultsTest.define_tables + def test_insert_executemany_no_defaults_passed(self, connection): t1 = self.tables.t1 result = connection.execute( @@ -698,7 +725,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): lambda: result.inserted_primary_key, ) - @testing.requires.insert_executemany_returning def test_insert_executemany_insdefault_passed(self, connection): t1 = self.tables.t1 result = connection.execute( @@ -742,7 +768,6 @@ class ReturnDefaultsTest(fixtures.TablesTest): lambda: result.inserted_primary_key, ) - @testing.requires.insert_executemany_returning def test_insert_executemany_only_pk_passed(self, connection): t1 = self.tables.t1 result = connection.execute( -- 2.47.3