From 94169108cdd4dace09b752a6af4f4404819b49a3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 7 Jun 2021 09:21:25 -0400 Subject: [PATCH] init extra_criteria_entities in fromstatement w/ DML Fixed issue in experimental "select ORM objects from INSERT/UPDATE" use case where an error was raised if the statement were against a single-table-inheritance subclass. Additionally makes some adjustments in the SQL assertion fixture to test a FromStatement w/ DML. Fixes: #6591 Change-Id: I53a627ab18a01dc6d9b5037e28312a1177891327 --- doc/build/changelog/unreleased_14/6591.rst | 7 ++++ lib/sqlalchemy/orm/context.py | 4 ++- lib/sqlalchemy/testing/assertions.py | 8 ++--- test/orm/inheritance/test_single.py | 38 ++++++++++++++++++++++ 4 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6591.rst diff --git a/doc/build/changelog/unreleased_14/6591.rst b/doc/build/changelog/unreleased_14/6591.rst new file mode 100644 index 0000000000..74cbcc5f4c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6591.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 6591 + + Fixed issue in experimental "select ORM objects from INSERT/UPDATE" use + case where an error was raised if the statement were against a + single-table-inheritance subclass. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index baad288359..d60758ffcd 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -431,6 +431,9 @@ class ORMFromStatementCompileState(ORMCompileState): if isinstance( self.statement, (expression.TextClause, expression.UpdateBase) ): + + self.extra_criteria_entities = {} + # setup for all entities. Currently, this is not useful # for eager loaders, as the eager loaders that work are able # to do their work entirely in row_processor. @@ -709,7 +712,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # i.e. when each _MappedEntity has its own FROM if self.compile_options._enable_single_crit: - self._adjust_for_extra_criteria() if not self.primary_columns: diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index b618021a67..cf61bf95ca 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -487,15 +487,15 @@ class AssertsCompiledSQL(object): self.supports_execution = getattr( test_statement, "supports_execution", False ) + if self.supports_execution: self._execution_options = test_statement._execution_options - if isinstance( - test_statement, (sql.Insert, sql.Update, sql.Delete) - ): + if hasattr(test_statement, "_returning"): self._returning = test_statement._returning - if isinstance(test_statement, (sql.Insert, sql.Update)): + if hasattr(test_statement, "_inline"): self._inline = test_statement._inline + if hasattr(test_statement, "_return_defaults"): self._return_defaults = test_statement._return_defaults def _default_dialect(self): diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index ececaf882f..873b808ec9 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -369,6 +369,44 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): "WHERE employees_1.type IN ([POSTCOMPILE_type_1])", ) + def test_from_statement_select(self): + Engineer = self.classes.Engineer + + stmt = select(Engineer) + + q = select(Engineer).from_statement(stmt) + + self.assert_compile( + q, + "SELECT employees.employee_id, employees.name, " + "employees.manager_data, employees.engineer_info, " + "employees.type FROM employees WHERE employees.type " + "IN ([POSTCOMPILE_type_1])", + ) + + def test_from_statement_update(self): + """test #6591""" + + Engineer = self.classes.Engineer + + from sqlalchemy import update + + stmt = ( + update(Engineer) + .values(engineer_info="bar") + .returning(Engineer.employee_id) + ) + + q = select(Engineer).from_statement(stmt) + + self.assert_compile( + q, + "UPDATE employees SET engineer_info=:engineer_info " + "WHERE employees.type IN ([POSTCOMPILE_type_1]) " + "RETURNING employees.employee_id", + dialect="default_enhanced", + ) + def test_union_modifiers(self): Engineer, Manager = self.classes("Engineer", "Manager") -- 2.47.2