From: Mike Bayer Date: Wed, 6 Jul 2022 01:05:18 +0000 (-0400) Subject: generalize sql server check for id col to accommodate ORM cases X-Git-Tag: rel_2_0_0b1~187^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=370d73b4e4b3009ea5feed1341ead965f6aa98bb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git generalize sql server check for id col to accommodate ORM cases Fixed issues that prevented the new usage patterns for using DML with ORM objects presented at :ref:`orm_dml_returning_objects` from working correctly with the SQL Server pyodbc dialect. Here we add a step to look in compile_state._dict_values more thoroughly for the keys we need to determine "identity insert" or not, and also add a new compiler variable dml_compile_state so that we can skip the ORM's compile_state if present. Fixes: #8210 Change-Id: Idbd76bb3eb075c647dc6c1cb78f7315c821e15f7 --- diff --git a/doc/build/changelog/unreleased_14/8210.rst b/doc/build/changelog/unreleased_14/8210.rst new file mode 100644 index 0000000000..f99d86194f --- /dev/null +++ b/doc/build/changelog/unreleased_14/8210.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, mssql + :tickets: 8210 + + Fixed issues that prevented the new usage patterns for using DML with ORM + objects presented at :ref:`orm_dml_returning_objects` from working + correctly with the SQL Server pyodbc dialect. + diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 2a4362ccb9..ed4139ad17 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1751,15 +1751,12 @@ class MSExecutionContext(default.DefaultExecutionContext): not isinstance(id_column.default, Sequence) ): insert_has_identity = True - compile_state = self.compiled.compile_state + compile_state = self.compiled.dml_compile_state self._enable_identity_insert = ( id_column.key in self.compiled_parameters[0] ) or ( compile_state._dict_parameters - and ( - id_column.key in compile_state._dict_parameters - or id_column in compile_state._dict_parameters - ) + and (id_column.key in compile_state._insert_col_keys) ) else: diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index d56035db7b..0b250a28ec 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -162,6 +162,17 @@ def expect( ... +@overload +def expect( + role: Type[roles.DMLColumnRole], + element: Any, + *, + as_key: Literal[True] = ..., + **kw: Any, +) -> str: + ... + + @overload def expect( role: Type[roles.LiteralValueRole], @@ -420,9 +431,11 @@ def expect( ) -def expect_as_key(role, element, **kw): - kw["as_key"] = True - return expect(role, element, **kw) +def expect_as_key( + role: Type[roles.DMLColumnRole], element: Any, **kw: Any +) -> str: + kw.pop("as_key", None) + return expect(role, element, as_key=True, **kw) def expect_col_expression_collection( diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 60ec09771f..a11d83b11c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -529,6 +529,18 @@ class Compiled: """ + dml_compile_state: Optional[CompileState] = None + """Optional :class:`.CompileState` assigned at the same point that + .isinsert, .isupdate, or .isdelete is assigned. + + This will normally be the same object as .compile_state, with the + exception of cases like the :class:`.ORMFromStatementCompileState` + object. + + .. versionadded:: 1.4.40 + + """ + cache_key: Optional[CacheKey] = None """The :class:`.CacheKey` that was generated ahead of creating this :class:`.Compiled` object. @@ -4371,6 +4383,8 @@ class SQLCompiler(Compiled): if toplevel: self.isinsert = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state if not self.compile_state: self.compile_state = compile_state @@ -4548,6 +4562,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state if not self.compile_state: self.compile_state = compile_state @@ -4683,6 +4699,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isdelete = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state if not self.compile_state: self.compile_state = compile_state diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 2ed3be9cbd..e99f354188 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -263,7 +263,7 @@ class DMLState(CompileState): def _process_select_values(self, statement: ValuesBase) -> None: assert statement._select_names is not None - parameters = { + parameters: MutableMapping[_DMLColumnElement, Any] = { coercions.expect(roles.DMLColumnRole, name, as_key=True): Null() for name in statement._select_names } @@ -312,6 +312,14 @@ class InsertDMLState(DMLState): if statement._multi_values: self._process_multi_values(statement) + @util.memoized_property + def _insert_col_keys(self) -> List[str]: + # this is also done in crud.py -> _key_getters_for_crud_column + return [ + coercions.expect_as_key(roles.DMLColumnRole, col) + for col in self._dict_parameters or () + ] + @CompileState.plugin_for("default", "update") class UpdateDMLState(DMLState): diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 22d827be97..933d2bb1fa 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -2253,21 +2253,45 @@ class LoadFromReturningTest(fixtures.MappedTest): [User(name="jack", age=52), User(name="jill", age=34)], ) - def test_load_from_insert(self, connection): + @testing.combinations( + ("single",), + ("multiple", testing.requires.multivalues_inserts), + argnames="params", + ) + def test_load_from_insert(self, connection, params): User = self.classes.User - stmt = ( - insert(User) - .values({User.id: 5, User.age: 25, User.name: "spongebob"}) - .returning(User) - ) + if params == "multiple": + values = [ + {User.id: 5, User.age: 25, User.name: "spongebob"}, + {User.id: 6, User.age: 30, User.name: "patrick"}, + {User.id: 7, User.age: 35, User.name: "squidward"}, + ] + elif params == "single": + values = {User.id: 5, User.age: 25, User.name: "spongebob"} + else: + assert False + + stmt = insert(User).values(values).returning(User) stmt = select(User).from_statement(stmt) with Session(connection) as sess: rows = sess.execute(stmt).scalars().all() - eq_( - rows, - [User(name="spongebob", age=25)], - ) + if params == "multiple": + eq_( + rows, + [ + User(name="spongebob", age=25), + User(name="patrick", age=30), + User(name="squidward", age=35), + ], + ) + elif params == "single": + eq_( + rows, + [User(name="spongebob", age=25)], + ) + else: + assert False diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index 3e51e9450c..b6945813e6 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -19,6 +19,14 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +class ExpectExpr: + def __init__(self, element): + self.element = element + + def __clause_element__(self): + return self.element + + class InsertExecTest(fixtures.TablesTest): __backend__ = True @@ -35,13 +43,27 @@ class InsertExecTest(fixtures.TablesTest): ) @testing.requires.multivalues_inserts - def test_multivalues_insert(self, connection): + @testing.combinations("string", "column", "expect", argnames="keytype") + def test_multivalues_insert(self, connection, keytype): + users = self.tables.users + + if keytype == "string": + user_id, user_name = "user_id", "user_name" + elif keytype == "column": + user_id, user_name = users.c.user_id, users.c.user_name + elif keytype == "expect": + user_id, user_name = ExpectExpr(users.c.user_id), ExpectExpr( + users.c.user_name + ) + else: + assert False + connection.execute( users.insert().values( [ - {"user_id": 7, "user_name": "jack"}, - {"user_id": 8, "user_name": "ed"}, + {user_id: 7, user_name: "jack"}, + {user_id: 8, user_name: "ed"}, ] ) )