From: Mike Bayer Date: Wed, 6 Jul 2022 01:05:18 +0000 (-0400) Subject: generalize sql server check for id col to accommodate ORM cases X-Git-Tag: rel_1_4_40~33 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=844365266aeb2582d775d019c48e7ffa6113c673;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git generalize sql server check for id col to accommodate ORM cases Fixed issues that prevented the new usage patterns for using DML with ORM objects presented at :ref:`orm_dml_returning_objects` from working correctly with the SQL Server pyodbc dialect. Here we add a step to look in compile_state._dict_values more thoroughly for the keys we need to determine "identity insert" or not, and also add a new compiler variable dml_compile_state so that we can skip the ORM's compile_state if present. Fixes: #8210 Change-Id: Idbd76bb3eb075c647dc6c1cb78f7315c821e15f7 (cherry picked from commit 5806428800d2f1ac775156f90497a2fc3a644f35) --- diff --git a/doc/build/changelog/unreleased_14/8210.rst b/doc/build/changelog/unreleased_14/8210.rst new file mode 100644 index 0000000000..f99d86194f --- /dev/null +++ b/doc/build/changelog/unreleased_14/8210.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, mssql + :tickets: 8210 + + Fixed issues that prevented the new usage patterns for using DML with ORM + objects presented at :ref:`orm_dml_returning_objects` from working + correctly with the SQL Server pyodbc dialect. + diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 1658f27c70..735cc3cff8 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1651,15 +1651,12 @@ class MSExecutionContext(default.DefaultExecutionContext): ) if insert_has_identity: - compile_state = self.compiled.compile_state + compile_state = self.compiled.dml_compile_state self._enable_identity_insert = ( id_column.key in self.compiled_parameters[0] ) or ( compile_state._dict_parameters - and ( - id_column.key in compile_state._dict_parameters - or id_column in compile_state._dict_parameters - ) + and (id_column.key in compile_state._insert_col_keys) ) else: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 477c199c17..667dd7d3de 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -402,6 +402,18 @@ class Compiled(object): """ + dml_compile_state = None + """Optional :class:`.CompileState` assigned at the same point that + .isinsert, .isupdate, or .isdelete is assigned. + + This will normally be the same object as .compile_state, with the + exception of cases like the :class:`.ORMFromStatementCompileState` + object. + + .. versionadded:: 1.4.40 + + """ + cache_key = None _gen_time = None @@ -3838,6 +3850,8 @@ class SQLCompiler(Compiled): if toplevel: self.isinsert = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state if not self.compile_state: self.compile_state = compile_state @@ -4008,6 +4022,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state if not self.compile_state: self.compile_state = compile_state @@ -4134,6 +4150,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isdelete = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state if not self.compile_state: self.compile_state = compile_state diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index dea5d6119d..4a343147c9 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -189,6 +189,14 @@ class InsertDMLState(DMLState): if statement._multi_values: self._process_multi_values(statement) + @util.memoized_property + def _insert_col_keys(self): + # this is also done in crud.py -> _key_getters_for_crud_column + return [ + coercions.expect_as_key(roles.DMLColumnRole, col) + for col in self._dict_parameters + ] + @CompileState.plugin_for("default", "update") class UpdateDMLState(DMLState): diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 255d70f414..4eabe2f6c4 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -2217,21 +2217,45 @@ class LoadFromReturningTest(fixtures.MappedTest): [User(name="jack", age=52), User(name="jill", age=34)], ) - def test_load_from_insert(self, connection): + @testing.combinations( + ("single",), + ("multiple", testing.requires.multivalues_inserts), + argnames="params", + ) + def test_load_from_insert(self, connection, params): User = self.classes.User - stmt = ( - insert(User) - .values({User.id: 5, User.age: 25, User.name: "spongebob"}) - .returning(User) - ) + if params == "multiple": + values = [ + {User.id: 5, User.age: 25, User.name: "spongebob"}, + {User.id: 6, User.age: 30, User.name: "patrick"}, + {User.id: 7, User.age: 35, User.name: "squidward"}, + ] + elif params == "single": + values = {User.id: 5, User.age: 25, User.name: "spongebob"} + else: + assert False + + stmt = insert(User).values(values).returning(User) stmt = select(User).from_statement(stmt) with Session(connection) as sess: rows = sess.execute(stmt).scalars().all() - eq_( - rows, - [User(name="spongebob", age=25)], - ) + if params == "multiple": + eq_( + rows, + [ + User(name="spongebob", age=25), + User(name="patrick", age=30), + User(name="squidward", age=35), + ], + ) + elif params == "single": + eq_( + rows, + [User(name="spongebob", age=25)], + ) + else: + assert False diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 76b4ba01ea..334df9575e 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -20,6 +20,14 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +class ExpectExpr: + def __init__(self, element): + self.element = element + + def __clause_element__(self): + return self.element + + class InsertExecTest(fixtures.TablesTest): __backend__ = True @@ -36,13 +44,27 @@ class InsertExecTest(fixtures.TablesTest): ) @testing.requires.multivalues_inserts - def test_multivalues_insert(self, connection): + @testing.combinations("string", "column", "expect", argnames="keytype") + def test_multivalues_insert(self, connection, keytype): + users = self.tables.users + + if keytype == "string": + user_id, user_name = "user_id", "user_name" + elif keytype == "column": + user_id, user_name = users.c.user_id, users.c.user_name + elif keytype == "expect": + user_id, user_name = ExpectExpr(users.c.user_id), ExpectExpr( + users.c.user_name + ) + else: + assert False + connection.execute( users.insert().values( [ - {"user_id": 7, "user_name": "jack"}, - {"user_id": 8, "user_name": "ed"}, + {user_id: 7, user_name: "jack"}, + {user_id: 8, user_name: "ed"}, ] ) )