]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generalize sql server check for id col to accommodate ORM cases
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jul 2022 01:05:18 +0000 (21:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jul 2022 13:58:11 +0000 (09:58 -0400)
Fixed issues that prevented the new usage patterns for using DML with ORM
objects presented at :ref:`orm_dml_returning_objects` from working
correctly with the SQL Server pyodbc dialect.

Here we add a step to look in compile_state._dict_values more thoroughly
for the keys we need to determine "identity insert" or not, and also
add a new compiler variable dml_compile_state so that we can skip the
ORM's compile_state if present.

Fixes: #8210
Change-Id: Idbd76bb3eb075c647dc6c1cb78f7315c821e15f7

doc/build/changelog/unreleased_14/8210.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
test/orm/test_update_delete.py
test/sql/test_insert_exec.py

diff --git a/doc/build/changelog/unreleased_14/8210.rst b/doc/build/changelog/unreleased_14/8210.rst
new file mode 100644 (file)
index 0000000..f99d861
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 8210
+
+    Fixed issues that prevented the new usage patterns for using DML with ORM
+    objects presented at :ref:`orm_dml_returning_objects` from working
+    correctly with the SQL Server pyodbc dialect.
+
index 2a4362ccb9fb52de15604f62df21f8ffd674ab7b..ed4139ad172eca2ed127fb08653510267bbde495 100644 (file)
@@ -1751,15 +1751,12 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 not isinstance(id_column.default, Sequence)
             ):
                 insert_has_identity = True
-                compile_state = self.compiled.compile_state
+                compile_state = self.compiled.dml_compile_state
                 self._enable_identity_insert = (
                     id_column.key in self.compiled_parameters[0]
                 ) or (
                     compile_state._dict_parameters
-                    and (
-                        id_column.key in compile_state._dict_parameters
-                        or id_column in compile_state._dict_parameters
-                    )
+                    and (id_column.key in compile_state._insert_col_keys)
                 )
 
             else:
index d56035db7b3c72a34daac309a093cdf7ec614f05..0b250a28ec50c45b476e75f6ab6f0e97456ebfd7 100644 (file)
@@ -162,6 +162,17 @@ def expect(
     ...
 
 
+@overload
+def expect(
+    role: Type[roles.DMLColumnRole],
+    element: Any,
+    *,
+    as_key: Literal[True] = ...,
+    **kw: Any,
+) -> str:
+    ...
+
+
 @overload
 def expect(
     role: Type[roles.LiteralValueRole],
@@ -420,9 +431,11 @@ def expect(
         )
 
 
-def expect_as_key(role, element, **kw):
-    kw["as_key"] = True
-    return expect(role, element, **kw)
+def expect_as_key(
+    role: Type[roles.DMLColumnRole], element: Any, **kw: Any
+) -> str:
+    kw.pop("as_key", None)
+    return expect(role, element, as_key=True, **kw)
 
 
 def expect_col_expression_collection(
index 60ec09771f373e28d37334effae2e8e3d5f05f3f..a11d83b11cd7e9bf98dfe41fe9c70b0f733a8779 100644 (file)
@@ -529,6 +529,18 @@ class Compiled:
 
     """
 
+    dml_compile_state: Optional[CompileState] = None
+    """Optional :class:`.CompileState` assigned at the same point that
+    .isinsert, .isupdate, or .isdelete is assigned.
+
+    This will normally be the same object as .compile_state, with the
+    exception of cases like the :class:`.ORMFromStatementCompileState`
+    object.
+
+    .. versionadded:: 1.4.40
+
+    """
+
     cache_key: Optional[CacheKey] = None
     """The :class:`.CacheKey` that was generated ahead of creating this
     :class:`.Compiled` object.
@@ -4371,6 +4383,8 @@ class SQLCompiler(Compiled):
 
         if toplevel:
             self.isinsert = True
+            if not self.dml_compile_state:
+                self.dml_compile_state = compile_state
             if not self.compile_state:
                 self.compile_state = compile_state
 
@@ -4548,6 +4562,8 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         if toplevel:
             self.isupdate = True
+            if not self.dml_compile_state:
+                self.dml_compile_state = compile_state
             if not self.compile_state:
                 self.compile_state = compile_state
 
@@ -4683,6 +4699,8 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         if toplevel:
             self.isdelete = True
+            if not self.dml_compile_state:
+                self.dml_compile_state = compile_state
             if not self.compile_state:
                 self.compile_state = compile_state
 
index 2ed3be9cbd4b45b13ff39fe5898035d9a6e0a481..e99f3541886a611a68a43900841ab49db97eaaf0 100644 (file)
@@ -263,7 +263,7 @@ class DMLState(CompileState):
 
     def _process_select_values(self, statement: ValuesBase) -> None:
         assert statement._select_names is not None
-        parameters = {
+        parameters: MutableMapping[_DMLColumnElement, Any] = {
             coercions.expect(roles.DMLColumnRole, name, as_key=True): Null()
             for name in statement._select_names
         }
@@ -312,6 +312,14 @@ class InsertDMLState(DMLState):
         if statement._multi_values:
             self._process_multi_values(statement)
 
+    @util.memoized_property
+    def _insert_col_keys(self) -> List[str]:
+        # this is also done in crud.py -> _key_getters_for_crud_column
+        return [
+            coercions.expect_as_key(roles.DMLColumnRole, col)
+            for col in self._dict_parameters or ()
+        ]
+
 
 @CompileState.plugin_for("default", "update")
 class UpdateDMLState(DMLState):
index 22d827be977b68df8104e344124d769aef56e5fd..933d2bb1fa47df044d6a3666ccd49f9888fc4e3a 100644 (file)
@@ -2253,21 +2253,45 @@ class LoadFromReturningTest(fixtures.MappedTest):
                 [User(name="jack", age=52), User(name="jill", age=34)],
             )
 
-    def test_load_from_insert(self, connection):
+    @testing.combinations(
+        ("single",),
+        ("multiple", testing.requires.multivalues_inserts),
+        argnames="params",
+    )
+    def test_load_from_insert(self, connection, params):
         User = self.classes.User
 
-        stmt = (
-            insert(User)
-            .values({User.id: 5, User.age: 25, User.name: "spongebob"})
-            .returning(User)
-        )
+        if params == "multiple":
+            values = [
+                {User.id: 5, User.age: 25, User.name: "spongebob"},
+                {User.id: 6, User.age: 30, User.name: "patrick"},
+                {User.id: 7, User.age: 35, User.name: "squidward"},
+            ]
+        elif params == "single":
+            values = {User.id: 5, User.age: 25, User.name: "spongebob"}
+        else:
+            assert False
+
+        stmt = insert(User).values(values).returning(User)
 
         stmt = select(User).from_statement(stmt)
 
         with Session(connection) as sess:
             rows = sess.execute(stmt).scalars().all()
 
-            eq_(
-                rows,
-                [User(name="spongebob", age=25)],
-            )
+            if params == "multiple":
+                eq_(
+                    rows,
+                    [
+                        User(name="spongebob", age=25),
+                        User(name="patrick", age=30),
+                        User(name="squidward", age=35),
+                    ],
+                )
+            elif params == "single":
+                eq_(
+                    rows,
+                    [User(name="spongebob", age=25)],
+                )
+            else:
+                assert False
index 3e51e9450c2089d12d65f8622ca2cec53ec27902..b6945813e62ffffe6a626ebb16bbb2a96bd55f49 100644 (file)
@@ -19,6 +19,14 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
 
+class ExpectExpr:
+    def __init__(self, element):
+        self.element = element
+
+    def __clause_element__(self):
+        return self.element
+
+
 class InsertExecTest(fixtures.TablesTest):
     __backend__ = True
 
@@ -35,13 +43,27 @@ class InsertExecTest(fixtures.TablesTest):
         )
 
     @testing.requires.multivalues_inserts
-    def test_multivalues_insert(self, connection):
+    @testing.combinations("string", "column", "expect", argnames="keytype")
+    def test_multivalues_insert(self, connection, keytype):
+
         users = self.tables.users
+
+        if keytype == "string":
+            user_id, user_name = "user_id", "user_name"
+        elif keytype == "column":
+            user_id, user_name = users.c.user_id, users.c.user_name
+        elif keytype == "expect":
+            user_id, user_name = ExpectExpr(users.c.user_id), ExpectExpr(
+                users.c.user_name
+            )
+        else:
+            assert False
+
         connection.execute(
             users.insert().values(
                 [
-                    {"user_id": 7, "user_name": "jack"},
-                    {"user_id": 8, "user_name": "ed"},
+                    {user_id: 7, user_name: "jack"},
+                    {user_id: 8, user_name: "ed"},
                 ]
             )
         )