]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
init extra_criteria_entities in fromstatement w/ DML
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Jun 2021 13:21:25 +0000 (09:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 7 Jun 2021 13:22:47 +0000 (09:22 -0400)
Fixed issue in experimental "select ORM objects from INSERT/UPDATE" use
case where an error was raised if the statement were against a
single-table-inheritance subclass.

Additionally makes some adjustments in the SQL assertion
fixture to test a FromStatement w/ DML.

Fixes: #6591
Change-Id: I53a627ab18a01dc6d9b5037e28312a1177891327

doc/build/changelog/unreleased_14/6591.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/testing/assertions.py
test/orm/inheritance/test_single.py

diff --git a/doc/build/changelog/unreleased_14/6591.rst b/doc/build/changelog/unreleased_14/6591.rst
new file mode 100644 (file)
index 0000000..74cbcc5
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6591
+
+    Fixed issue in experimental "select ORM objects from INSERT/UPDATE" use
+    case where an error was raised if the statement were against a
+    single-table-inheritance subclass.
index baad288359163e7697a6e41610652134fdde63f7..d60758ffcd0a536074f0d2c6dc4664efc617eb21 100644 (file)
@@ -431,6 +431,9 @@ class ORMFromStatementCompileState(ORMCompileState):
         if isinstance(
             self.statement, (expression.TextClause, expression.UpdateBase)
         ):
+
+            self.extra_criteria_entities = {}
+
             # setup for all entities. Currently, this is not useful
             # for eager loaders, as the eager loaders that work are able
             # to do their work entirely in row_processor.
@@ -709,7 +712,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         # i.e. when each _MappedEntity has its own FROM
 
         if self.compile_options._enable_single_crit:
-
             self._adjust_for_extra_criteria()
 
         if not self.primary_columns:
index b618021a67b8047cabee2992a6eab7043e58a9af..cf61bf95caf1cf0d2f6083c757863d99e88ad65a 100644 (file)
@@ -487,15 +487,15 @@ class AssertsCompiledSQL(object):
                 self.supports_execution = getattr(
                     test_statement, "supports_execution", False
                 )
+
                 if self.supports_execution:
                     self._execution_options = test_statement._execution_options
 
-                    if isinstance(
-                        test_statement, (sql.Insert, sql.Update, sql.Delete)
-                    ):
+                    if hasattr(test_statement, "_returning"):
                         self._returning = test_statement._returning
-                    if isinstance(test_statement, (sql.Insert, sql.Update)):
+                    if hasattr(test_statement, "_inline"):
                         self._inline = test_statement._inline
+                    if hasattr(test_statement, "_return_defaults"):
                         self._return_defaults = test_statement._return_defaults
 
             def _default_dialect(self):
index ececaf882f88e799b4d882377d7e857158095bb6..873b808ec91c0646aad0389029ea423c205753ac 100644 (file)
@@ -369,6 +369,44 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
             "WHERE employees_1.type IN ([POSTCOMPILE_type_1])",
         )
 
+    def test_from_statement_select(self):
+        Engineer = self.classes.Engineer
+
+        stmt = select(Engineer)
+
+        q = select(Engineer).from_statement(stmt)
+
+        self.assert_compile(
+            q,
+            "SELECT employees.employee_id, employees.name, "
+            "employees.manager_data, employees.engineer_info, "
+            "employees.type FROM employees WHERE employees.type "
+            "IN ([POSTCOMPILE_type_1])",
+        )
+
+    def test_from_statement_update(self):
+        """test #6591"""
+
+        Engineer = self.classes.Engineer
+
+        from sqlalchemy import update
+
+        stmt = (
+            update(Engineer)
+            .values(engineer_info="bar")
+            .returning(Engineer.employee_id)
+        )
+
+        q = select(Engineer).from_statement(stmt)
+
+        self.assert_compile(
+            q,
+            "UPDATE employees SET engineer_info=:engineer_info "
+            "WHERE employees.type IN ([POSTCOMPILE_type_1]) "
+            "RETURNING employees.employee_id",
+            dialect="default_enhanced",
+        )
+
     def test_union_modifiers(self):
         Engineer, Manager = self.classes("Engineer", "Manager")