From: Mike Bayer Date: Mon, 17 Apr 2023 17:46:12 +0000 (-0400) Subject: apply criteria options from top-level core-only statement X-Git-Tag: rel_2_0_10~11^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a05ae2c7ce0c056eef549d078faa2ca20356d35c;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git apply criteria options from top-level core-only statement Made an improvement to the :func:`_orm.with_loader_criteria` loader option to allow it to be indicated in the :meth:`.Executable.options` method of a top-level statement that is not itself an ORM statement. Examples include :func:`_sql.select` that's embedded in compound statements such as :func:`_sql.union`, within an :meth:`_dml.Insert.from_select` construct, as well as within CTE expressions that are not ORM related at the top level. Improved propagation of :func:`_orm.with_loader_criteria` within ORM enabled UPDATE and DELETE statements as well. Fixes: #9635 Change-Id: I088ad91929dc797c06f292f5dc547d48ffb30430 --- diff --git a/doc/build/changelog/unreleased_20/9635.rst b/doc/build/changelog/unreleased_20/9635.rst new file mode 100644 index 0000000000..73281c7e1d --- /dev/null +++ b/doc/build/changelog/unreleased_20/9635.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, orm + :tickets: 9635 + + Made an improvement to the :func:`_orm.with_loader_criteria` loader option + to allow it to be indicated in the :meth:`.Executable.options` method of a + top-level statement that is not itself an ORM statement. Examples include + :func:`_sql.select` that's embedded in compound statements such as + :func:`_sql.union`, within an :meth:`_dml.Insert.from_select` construct, as + well as within CTE expressions that are not ORM related at the top level. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 1b3cce47ab..f9d9d6a433 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -1346,15 +1346,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): self.mapper = mapper = ext_info.mapper - self.extra_criteria_entities = {} - self._resolved_values = self._get_resolved_values(mapper, statement) - extra_criteria_attributes = {} - - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(extra_criteria_attributes) + self._init_global_attributes( + statement, + compiler, + toplevel=True, + process_criteria_for_toplevel=True, + ) if statement._values: self._resolved_values = dict(self._resolved_values) @@ -1372,7 +1371,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): new_stmt._values = self._resolved_values new_crit = self._adjust_for_extra_criteria( - extra_criteria_attributes, mapper + self.global_attributes, mapper ) if new_crit: new_stmt = new_stmt.where(*new_crit) @@ -1741,19 +1740,18 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): ext_info = statement.table._annotations["parententity"] self.mapper = mapper = ext_info.mapper - self.extra_criteria_entities = {} - - extra_criteria_attributes = {} - - for opt in statement._with_options: - if opt._is_criteria_option: - opt.get_global_criteria(extra_criteria_attributes) + self._init_global_attributes( + statement, + compiler, + toplevel=True, + process_criteria_for_toplevel=True, + ) new_stmt = statement._clone() new_stmt.table = mapper.local_table new_crit = cls._adjust_for_extra_criteria( - extra_criteria_attributes, mapper + self.global_attributes, mapper ) if new_crit: new_stmt = new_stmt.where(*new_crit) diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 2b45b5adc4..e778c48408 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -209,6 +209,45 @@ _orm_load_exec_options = util.immutabledict( class AbstractORMCompileState(CompileState): is_dml_returning = False + def _init_global_attributes( + self, statement, compiler, *, toplevel, process_criteria_for_toplevel + ): + self.attributes = {} + + if compiler is None: + # this is the legacy / testing only ORM _compile_state() use case. + # there is no need to apply criteria options for this. + self.global_attributes = ga = {} + assert toplevel + return + else: + self.global_attributes = ga = compiler._global_attributes + + if toplevel: + ga["toplevel_orm"] = True + + if process_criteria_for_toplevel: + for opt in statement._with_options: + if opt._is_criteria_option: + opt.process_compile_state(self) + + return + elif ga.get("toplevel_orm", False): + return + + stack_0 = compiler.stack[0] + + try: + toplevel_stmt = stack_0["selectable"] + except KeyError: + pass + else: + for opt in toplevel_stmt._with_options: + if opt._is_compile_state and opt._is_criteria_option: + opt.process_compile_state(self) + + ga["toplevel_orm"] = True + @classmethod def create_for_statement( cls, @@ -622,17 +661,13 @@ class ORMFromStatementCompileState(ORMCompileState): assert isinstance(statement_container, FromStatement) - if compiler is not None: - toplevel = not compiler.stack - else: - toplevel = True - - if not toplevel: + if compiler is not None and compiler.stack: raise sa_exc.CompileError( "The ORM FromStatement construct only supports being " "invoked as the topmost statement, as it is only intended to " "define how result rows should be returned." ) + self = cls.__new__(cls) self._primary_entity = None @@ -680,18 +715,18 @@ class ORMFromStatementCompileState(ORMCompileState): self.current_path = statement_container._compile_options._current_path - if toplevel and statement_container._with_options: - self.attributes = {} - self.global_attributes = compiler._global_attributes + self._init_global_attributes( + statement_container, + compiler, + process_criteria_for_toplevel=False, + toplevel=True, + ) + if statement_container._with_options: for opt in statement_container._with_options: if opt._is_compile_state: opt.process_compile_state(self) - else: - self.attributes = {} - self.global_attributes = compiler._global_attributes - if statement_container._with_context_options: for fn, key in statement_container._with_context_options: fn(self) @@ -911,10 +946,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): if compiler is not None: toplevel = not compiler.stack - self.global_attributes = compiler._global_attributes else: toplevel = True - self.global_attributes = {} select_statement = statement @@ -1002,11 +1035,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.eager_order_by = () + self._init_global_attributes( + select_statement, + compiler, + toplevel=toplevel, + process_criteria_for_toplevel=False, + ) + if toplevel and ( select_statement._with_options or select_statement._memoized_select_entities ): - self.attributes = {} for ( memoized_entities @@ -1028,9 +1067,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState): if opt._is_compile_state: opt.process_compile_state(self) - else: - self.attributes = {} - # uncomment to print out the context.attributes structure # after it's been set up above # self._dump_option_struct() diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index 58244c4620..c02f7af4c3 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -3,10 +3,12 @@ import random from sqlalchemy import Column from sqlalchemy import DateTime +from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import literal_column from sqlalchemy import orm @@ -14,6 +16,8 @@ from sqlalchemy import select from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import union +from sqlalchemy import update from sqlalchemy.orm import aliased from sqlalchemy.orm import column_property from sqlalchemy.orm import defer @@ -588,6 +592,238 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): "FROM users WHERE users.name != :name_1", ) + @testing.variation("style", ["direct_union", "from_statement"]) + @testing.variation("add_nested_union", [True, False]) + def test_select_mapper_columns_w_union_mapper_criteria( + self, multi_mixin_fixture, style: testing.Variation, add_nested_union + ): + """test #9635""" + HasFoob, Order, Item = multi_mixin_fixture + + stmt = ( + select(Order.id, Order.description) + .where(Order.id > 8) + .union(select(Order.id, Order.description).where(Order.id <= 8)) + ) + + if add_nested_union: + stmt = union( + stmt, + union( + select(Item.id, Item.description).where(Item.id <= 8), + select(Item.id, Item.description).where(Item.id > 8), + ), + ) + + if style.direct_union: + stmt = stmt.options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description != "name", + include_aliases=True, + ) + ) + elif style.from_statement: + + stmt = ( + select(Order.id, Order.description) + .from_statement(stmt) + .options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description != "name", + include_aliases=True, + ) + ) + ) + + else: + style.fail() + + if add_nested_union: + # the criteria is embedded into all UNIONS regardless of nesting. + self.assert_compile( + stmt, + "(SELECT orders.id, orders.description FROM orders WHERE " + "orders.id > :id_1 AND orders.description != :description_1 " + "UNION SELECT orders.id, orders.description FROM orders WHERE " + "orders.id <= :id_2 AND orders.description != :description_2) " + "UNION (SELECT items.id, items.description FROM items WHERE " + "items.id <= :id_3 AND items.description != :description_3 " + "UNION SELECT items.id, items.description FROM items WHERE " + "items.id > :id_4 AND items.description != :description_4)", + checkparams={ + "id_1": 8, + "description_1": "name", + "id_2": 8, + "description_2": "name", + "id_3": 8, + "description_3": "name", + "id_4": 8, + "description_4": "name", + }, + ) + else: + self.assert_compile( + stmt, + "SELECT orders.id, orders.description FROM orders WHERE " + "orders.id > :id_1 AND orders.description != :description_1 " + "UNION SELECT orders.id, orders.description FROM orders WHERE " + "orders.id <= :id_2 AND orders.description != :description_2", + checkparams={ + "description_1": "name", + "description_2": "name", + "id_1": 8, + "id_2": 8, + }, + ) + + def test_select_mapper_columns_w_core_dml_mapper_criteria( + self, multi_mixin_fixture + ): + """test #9635""" + HasFoob, Order, Item = multi_mixin_fixture + + stmt = ( + insert(Order) + .from_select( + ["id", "description"], + select(Order.id, Order.description).where(Order.id > 8), + ) + .options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description != "name", + include_aliases=True, + ) + ) + ) + + self.assert_compile( + stmt, + "INSERT INTO orders (id, description) SELECT orders.id, " + "orders.description FROM orders WHERE orders.id > :id_1 " + "AND orders.description != :description_1", + checkparams={"description_1": "name", "id_1": 8}, + ) + + @testing.variation("update_is_orm", [True, False]) + def test_select_mapper_columns_w_core_cte_update_mapper_criteria( + self, multi_mixin_fixture, update_is_orm + ): + """test #9635""" + HasFoob, Order, Item = multi_mixin_fixture + + cte = select(Order).cte("pd") + + if update_is_orm: + stmt = ( + update(Order) + .where(Order.id == cte.c.id) + .values(description="newname") + ) + else: + stmt = ( + update(Order.__table__) + .where(Order.__table__.c.id == cte.c.id) + .values(description="newname") + ) + + stmt = stmt.options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description != "name", + include_aliases=True, + ) + ) + + if update_is_orm: + self.assert_compile( + stmt, + "WITH pd AS (SELECT orders.id AS id, " + "orders.user_id AS user_id, " + "orders.address_id AS address_id, " + "orders.description AS description, orders.isopen AS isopen " + "FROM orders WHERE orders.description != %(description_1)s) " + "UPDATE orders SET description=%(description)s " + "FROM pd WHERE orders.id = pd.id " + "AND orders.description != %(description_2)s", + dialect="postgresql", + checkparams={ + "description": "newname", + "description_1": "name", + "description_2": "name", + }, + ) + else: + # non ORM update, no criteria, but criteria still gets rendered + # inside the SELECT + self.assert_compile( + stmt, + "WITH pd AS (SELECT orders.id AS id, " + "orders.user_id AS user_id, " + "orders.address_id AS address_id, " + "orders.description AS description, orders.isopen AS isopen " + "FROM orders WHERE orders.description != %(description_1)s) " + "UPDATE orders SET description=%(description)s " + "FROM pd WHERE orders.id = pd.id", + dialect="postgresql", + checkparams={ + "description": "newname", + "description_1": "name", + }, + ) + + @testing.variation("delete_is_orm", [True, False]) + def test_select_mapper_columns_w_core_cte_delete_mapper_criteria( + self, multi_mixin_fixture, delete_is_orm + ): + """test #9635""" + HasFoob, Order, Item = multi_mixin_fixture + + cte = select(Order).cte("pd") + + if delete_is_orm: + stmt = delete(Order).where(Order.id == cte.c.id) + else: + stmt = delete(Order.__table__).where( + Order.__table__.c.id == cte.c.id + ) + + stmt = stmt.options( + with_loader_criteria( + HasFoob, + lambda cls: cls.description != "name", + include_aliases=True, + ) + ) + + if delete_is_orm: + self.assert_compile( + stmt, + "WITH pd AS (SELECT orders.id AS id, orders.user_id AS " + "user_id, orders.address_id AS address_id, " + "orders.description AS description, orders.isopen AS isopen " + "FROM orders WHERE orders.description != %(description_1)s) " + "DELETE FROM orders USING pd WHERE orders.id = pd.id " + "AND orders.description != %(description_2)s", + dialect="postgresql", + checkparams={"description_1": "name", "description_2": "name"}, + ) + else: + # non ORM update, no criteria, but criteria still gets rendered + # inside the SELECT + self.assert_compile( + stmt, + "WITH pd AS (SELECT orders.id AS id, orders.user_id AS " + "user_id, orders.address_id AS address_id, " + "orders.description AS description, orders.isopen AS isopen " + "FROM orders WHERE orders.description != %(description_1)s) " + "DELETE FROM orders USING pd WHERE orders.id = pd.id", + dialect="postgresql", + checkparams={"description_1": "name"}, + ) + def test_select_join_mapper_mapper_criteria(self, user_address_fixture): User, Address = user_address_fixture