]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
generalize sql server check for id col to accommodate ORM cases
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jul 2022 01:05:18 +0000 (21:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jul 2022 02:02:43 +0000 (22:02 -0400)
Fixed issues that prevented the new usage patterns for using DML with ORM
objects presented at :ref:`orm_dml_returning_objects` from working
correctly with the SQL Server pyodbc dialect.

Here we add a step to look in compile_state._dict_values more thoroughly
for the keys we need to determine "identity insert" or not, and also
add a new compiler variable dml_compile_state so that we can skip the
ORM's compile_state if present.

Fixes: #8210
Change-Id: Idbd76bb3eb075c647dc6c1cb78f7315c821e15f7
(cherry picked from commit 5806428800d2f1ac775156f90497a2fc3a644f35)

doc/build/changelog/unreleased_14/8210.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
test/orm/test_update_delete.py
test/sql/test_insert_exec.py

diff --git a/doc/build/changelog/unreleased_14/8210.rst b/doc/build/changelog/unreleased_14/8210.rst
new file mode 100644 (file)
index 0000000..f99d861
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 8210
+
+    Fixed issues that prevented the new usage patterns for using DML with ORM
+    objects presented at :ref:`orm_dml_returning_objects` from working
+    correctly with the SQL Server pyodbc dialect.
+
index 1658f27c70c12ed3b2e57d0536eb9a720a6e822f..735cc3cff82880cb43bb61c80da9a34e0656a41b 100644 (file)
@@ -1651,15 +1651,12 @@ class MSExecutionContext(default.DefaultExecutionContext):
             )
 
             if insert_has_identity:
-                compile_state = self.compiled.compile_state
+                compile_state = self.compiled.dml_compile_state
                 self._enable_identity_insert = (
                     id_column.key in self.compiled_parameters[0]
                 ) or (
                     compile_state._dict_parameters
-                    and (
-                        id_column.key in compile_state._dict_parameters
-                        or id_column in compile_state._dict_parameters
-                    )
+                    and (id_column.key in compile_state._insert_col_keys)
                 )
 
             else:
index 477c199c1759539bb7a055cfebbae84c3c1245d6..667dd7d3de5fc1bc06232d5192337b33ef888089 100644 (file)
@@ -402,6 +402,18 @@ class Compiled(object):
 
     """
 
+    dml_compile_state = None
+    """Optional :class:`.CompileState` assigned at the same point that
+    .isinsert, .isupdate, or .isdelete is assigned.
+
+    This will normally be the same object as .compile_state, with the
+    exception of cases like the :class:`.ORMFromStatementCompileState`
+    object.
+
+    .. versionadded:: 1.4.40
+
+    """
+
     cache_key = None
     _gen_time = None
 
@@ -3838,6 +3850,8 @@ class SQLCompiler(Compiled):
 
         if toplevel:
             self.isinsert = True
+            if not self.dml_compile_state:
+                self.dml_compile_state = compile_state
             if not self.compile_state:
                 self.compile_state = compile_state
 
@@ -4008,6 +4022,8 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         if toplevel:
             self.isupdate = True
+            if not self.dml_compile_state:
+                self.dml_compile_state = compile_state
             if not self.compile_state:
                 self.compile_state = compile_state
 
@@ -4134,6 +4150,8 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         if toplevel:
             self.isdelete = True
+            if not self.dml_compile_state:
+                self.dml_compile_state = compile_state
             if not self.compile_state:
                 self.compile_state = compile_state
 
index dea5d6119df376da05b665394de9555351d44ae0..4a343147c9656a3c17c696a321e7624f5d26ecfe 100644 (file)
@@ -189,6 +189,14 @@ class InsertDMLState(DMLState):
         if statement._multi_values:
             self._process_multi_values(statement)
 
+    @util.memoized_property
+    def _insert_col_keys(self):
+        # this is also done in crud.py -> _key_getters_for_crud_column
+        return [
+            coercions.expect_as_key(roles.DMLColumnRole, col)
+            for col in self._dict_parameters
+        ]
+
 
 @CompileState.plugin_for("default", "update")
 class UpdateDMLState(DMLState):
index 255d70f4142c909ee0eb65f21a4079f688fb75bc..4eabe2f6c49ff706ebd22eaa81ba415e048ab940 100644 (file)
@@ -2217,21 +2217,45 @@ class LoadFromReturningTest(fixtures.MappedTest):
                 [User(name="jack", age=52), User(name="jill", age=34)],
             )
 
-    def test_load_from_insert(self, connection):
+    @testing.combinations(
+        ("single",),
+        ("multiple", testing.requires.multivalues_inserts),
+        argnames="params",
+    )
+    def test_load_from_insert(self, connection, params):
         User = self.classes.User
 
-        stmt = (
-            insert(User)
-            .values({User.id: 5, User.age: 25, User.name: "spongebob"})
-            .returning(User)
-        )
+        if params == "multiple":
+            values = [
+                {User.id: 5, User.age: 25, User.name: "spongebob"},
+                {User.id: 6, User.age: 30, User.name: "patrick"},
+                {User.id: 7, User.age: 35, User.name: "squidward"},
+            ]
+        elif params == "single":
+            values = {User.id: 5, User.age: 25, User.name: "spongebob"}
+        else:
+            assert False
+
+        stmt = insert(User).values(values).returning(User)
 
         stmt = select(User).from_statement(stmt)
 
         with Session(connection) as sess:
             rows = sess.execute(stmt).scalars().all()
 
-            eq_(
-                rows,
-                [User(name="spongebob", age=25)],
-            )
+            if params == "multiple":
+                eq_(
+                    rows,
+                    [
+                        User(name="spongebob", age=25),
+                        User(name="patrick", age=30),
+                        User(name="squidward", age=35),
+                    ],
+                )
+            elif params == "single":
+                eq_(
+                    rows,
+                    [User(name="spongebob", age=25)],
+                )
+            else:
+                assert False
index 76b4ba01ea89fc5c49b7ad6ed6ef0a8d6c7518ba..334df9575e9df9e1e27b730f6296bfeb991dd798 100644 (file)
@@ -20,6 +20,14 @@ from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
 
 
+class ExpectExpr:
+    def __init__(self, element):
+        self.element = element
+
+    def __clause_element__(self):
+        return self.element
+
+
 class InsertExecTest(fixtures.TablesTest):
     __backend__ = True
 
@@ -36,13 +44,27 @@ class InsertExecTest(fixtures.TablesTest):
         )
 
     @testing.requires.multivalues_inserts
-    def test_multivalues_insert(self, connection):
+    @testing.combinations("string", "column", "expect", argnames="keytype")
+    def test_multivalues_insert(self, connection, keytype):
+
         users = self.tables.users
+
+        if keytype == "string":
+            user_id, user_name = "user_id", "user_name"
+        elif keytype == "column":
+            user_id, user_name = users.c.user_id, users.c.user_name
+        elif keytype == "expect":
+            user_id, user_name = ExpectExpr(users.c.user_id), ExpectExpr(
+                users.c.user_name
+            )
+        else:
+            assert False
+
         connection.execute(
             users.insert().values(
                 [
-                    {"user_id": 7, "user_name": "jack"},
-                    {"user_id": 8, "user_name": "ed"},
+                    {user_id: 7, user_name: "jack"},
+                    {user_id: 8, user_name: "ed"},
                 ]
             )
         )