]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply criteria options from top-level core-only statement
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Apr 2023 17:46:12 +0000 (13:46 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 Apr 2023 19:32:00 +0000 (15:32 -0400)
Made an improvement to the :func:`_orm.with_loader_criteria` loader option
to allow it to be indicated in the :meth:`.Executable.options` method of a
top-level statement that is not itself an ORM statement. Examples include
:func:`_sql.select` that's embedded in compound statements such as
:func:`_sql.union`, within an :meth:`_dml.Insert.from_select` construct, as
well as within CTE expressions that are not ORM related at the top level.
Improved propagation of :func:`_orm.with_loader_criteria` within
ORM enabled UPDATE and DELETE statements as well.

Fixes: #9635
Change-Id: I088ad91929dc797c06f292f5dc547d48ffb30430

doc/build/changelog/unreleased_20/9635.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
test/orm/test_relationship_criteria.py

diff --git a/doc/build/changelog/unreleased_20/9635.rst b/doc/build/changelog/unreleased_20/9635.rst
new file mode 100644 (file)
index 0000000..73281c7
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9635
+
+    Made an improvement to the :func:`_orm.with_loader_criteria` loader option
+    to allow it to be indicated in the :meth:`.Executable.options` method of a
+    top-level statement that is not itself an ORM statement. Examples include
+    :func:`_sql.select` that's embedded in compound statements such as
+    :func:`_sql.union`, within an :meth:`_dml.Insert.from_select` construct, as
+    well as within CTE expressions that are not ORM related at the top level.
index 1b3cce47ab11f16e96b37cdd28720716a3dd3ea9..f9d9d6a433bb5384247b66414cc6f7251d40095a 100644 (file)
@@ -1346,15 +1346,14 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
 
         self.mapper = mapper = ext_info.mapper
 
-        self.extra_criteria_entities = {}
-
         self._resolved_values = self._get_resolved_values(mapper, statement)
 
-        extra_criteria_attributes = {}
-
-        for opt in statement._with_options:
-            if opt._is_criteria_option:
-                opt.get_global_criteria(extra_criteria_attributes)
+        self._init_global_attributes(
+            statement,
+            compiler,
+            toplevel=True,
+            process_criteria_for_toplevel=True,
+        )
 
         if statement._values:
             self._resolved_values = dict(self._resolved_values)
@@ -1372,7 +1371,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             new_stmt._values = self._resolved_values
 
         new_crit = self._adjust_for_extra_criteria(
-            extra_criteria_attributes, mapper
+            self.global_attributes, mapper
         )
         if new_crit:
             new_stmt = new_stmt.where(*new_crit)
@@ -1741,19 +1740,18 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
         ext_info = statement.table._annotations["parententity"]
         self.mapper = mapper = ext_info.mapper
 
-        self.extra_criteria_entities = {}
-
-        extra_criteria_attributes = {}
-
-        for opt in statement._with_options:
-            if opt._is_criteria_option:
-                opt.get_global_criteria(extra_criteria_attributes)
+        self._init_global_attributes(
+            statement,
+            compiler,
+            toplevel=True,
+            process_criteria_for_toplevel=True,
+        )
 
         new_stmt = statement._clone()
         new_stmt.table = mapper.local_table
 
         new_crit = cls._adjust_for_extra_criteria(
-            extra_criteria_attributes, mapper
+            self.global_attributes, mapper
         )
         if new_crit:
             new_stmt = new_stmt.where(*new_crit)
index 2b45b5adc4d90a0f02d9f92a8ba3e1db1e0457db..e778c4840852ee38b24b2259003e52c21371873d 100644 (file)
@@ -209,6 +209,45 @@ _orm_load_exec_options = util.immutabledict(
 class AbstractORMCompileState(CompileState):
     is_dml_returning = False
 
+    def _init_global_attributes(
+        self, statement, compiler, *, toplevel, process_criteria_for_toplevel
+    ):
+        self.attributes = {}
+
+        if compiler is None:
+            # this is the legacy / testing only ORM _compile_state() use case.
+            # there is no need to apply criteria options for this.
+            self.global_attributes = ga = {}
+            assert toplevel
+            return
+        else:
+            self.global_attributes = ga = compiler._global_attributes
+
+        if toplevel:
+            ga["toplevel_orm"] = True
+
+            if process_criteria_for_toplevel:
+                for opt in statement._with_options:
+                    if opt._is_criteria_option:
+                        opt.process_compile_state(self)
+
+            return
+        elif ga.get("toplevel_orm", False):
+            return
+
+        stack_0 = compiler.stack[0]
+
+        try:
+            toplevel_stmt = stack_0["selectable"]
+        except KeyError:
+            pass
+        else:
+            for opt in toplevel_stmt._with_options:
+                if opt._is_compile_state and opt._is_criteria_option:
+                    opt.process_compile_state(self)
+
+        ga["toplevel_orm"] = True
+
     @classmethod
     def create_for_statement(
         cls,
@@ -622,17 +661,13 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         assert isinstance(statement_container, FromStatement)
 
-        if compiler is not None:
-            toplevel = not compiler.stack
-        else:
-            toplevel = True
-
-        if not toplevel:
+        if compiler is not None and compiler.stack:
             raise sa_exc.CompileError(
                 "The ORM FromStatement construct only supports being "
                 "invoked as the topmost statement, as it is only intended to "
                 "define how result rows should be returned."
             )
+
         self = cls.__new__(cls)
         self._primary_entity = None
 
@@ -680,18 +715,18 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self.current_path = statement_container._compile_options._current_path
 
-        if toplevel and statement_container._with_options:
-            self.attributes = {}
-            self.global_attributes = compiler._global_attributes
+        self._init_global_attributes(
+            statement_container,
+            compiler,
+            process_criteria_for_toplevel=False,
+            toplevel=True,
+        )
 
+        if statement_container._with_options:
             for opt in statement_container._with_options:
                 if opt._is_compile_state:
                     opt.process_compile_state(self)
 
-        else:
-            self.attributes = {}
-            self.global_attributes = compiler._global_attributes
-
         if statement_container._with_context_options:
             for fn, key in statement_container._with_context_options:
                 fn(self)
@@ -911,10 +946,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         if compiler is not None:
             toplevel = not compiler.stack
-            self.global_attributes = compiler._global_attributes
         else:
             toplevel = True
-            self.global_attributes = {}
 
         select_statement = statement
 
@@ -1002,11 +1035,17 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
 
         self.eager_order_by = ()
 
+        self._init_global_attributes(
+            select_statement,
+            compiler,
+            toplevel=toplevel,
+            process_criteria_for_toplevel=False,
+        )
+
         if toplevel and (
             select_statement._with_options
             or select_statement._memoized_select_entities
         ):
-            self.attributes = {}
 
             for (
                 memoized_entities
@@ -1028,9 +1067,6 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
                 if opt._is_compile_state:
                     opt.process_compile_state(self)
 
-        else:
-            self.attributes = {}
-
         # uncomment to print out the context.attributes structure
         # after it's been set up above
         # self._dump_option_struct()
index 58244c4620eca53fd415c834707320f6983b2500..c02f7af4c30a58fae1682440278c6defde13a919 100644 (file)
@@ -3,10 +3,12 @@ import random
 
 from sqlalchemy import Column
 from sqlalchemy import DateTime
+from sqlalchemy import delete
 from sqlalchemy import event
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
 from sqlalchemy import orm
@@ -14,6 +16,8 @@ from sqlalchemy import select
 from sqlalchemy import sql
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import union
+from sqlalchemy import update
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import defer
@@ -588,6 +592,238 @@ class LoaderCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
             "FROM users WHERE users.name != :name_1",
         )
 
+    @testing.variation("style", ["direct_union", "from_statement"])
+    @testing.variation("add_nested_union", [True, False])
+    def test_select_mapper_columns_w_union_mapper_criteria(
+        self, multi_mixin_fixture, style: testing.Variation, add_nested_union
+    ):
+        """test #9635"""
+        HasFoob, Order, Item = multi_mixin_fixture
+
+        stmt = (
+            select(Order.id, Order.description)
+            .where(Order.id > 8)
+            .union(select(Order.id, Order.description).where(Order.id <= 8))
+        )
+
+        if add_nested_union:
+            stmt = union(
+                stmt,
+                union(
+                    select(Item.id, Item.description).where(Item.id <= 8),
+                    select(Item.id, Item.description).where(Item.id > 8),
+                ),
+            )
+
+        if style.direct_union:
+            stmt = stmt.options(
+                with_loader_criteria(
+                    HasFoob,
+                    lambda cls: cls.description != "name",
+                    include_aliases=True,
+                )
+            )
+        elif style.from_statement:
+
+            stmt = (
+                select(Order.id, Order.description)
+                .from_statement(stmt)
+                .options(
+                    with_loader_criteria(
+                        HasFoob,
+                        lambda cls: cls.description != "name",
+                        include_aliases=True,
+                    )
+                )
+            )
+
+        else:
+            style.fail()
+
+        if add_nested_union:
+            # the criteria is embedded into all UNIONS regardless of nesting.
+            self.assert_compile(
+                stmt,
+                "(SELECT orders.id, orders.description FROM orders WHERE "
+                "orders.id > :id_1 AND orders.description != :description_1 "
+                "UNION SELECT orders.id, orders.description FROM orders WHERE "
+                "orders.id <= :id_2 AND orders.description != :description_2) "
+                "UNION (SELECT items.id, items.description FROM items WHERE "
+                "items.id <= :id_3 AND items.description != :description_3 "
+                "UNION SELECT items.id, items.description FROM items WHERE "
+                "items.id > :id_4 AND items.description != :description_4)",
+                checkparams={
+                    "id_1": 8,
+                    "description_1": "name",
+                    "id_2": 8,
+                    "description_2": "name",
+                    "id_3": 8,
+                    "description_3": "name",
+                    "id_4": 8,
+                    "description_4": "name",
+                },
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT orders.id, orders.description FROM orders WHERE "
+                "orders.id > :id_1 AND orders.description != :description_1 "
+                "UNION SELECT orders.id, orders.description FROM orders WHERE "
+                "orders.id <= :id_2 AND orders.description != :description_2",
+                checkparams={
+                    "description_1": "name",
+                    "description_2": "name",
+                    "id_1": 8,
+                    "id_2": 8,
+                },
+            )
+
+    def test_select_mapper_columns_w_core_dml_mapper_criteria(
+        self, multi_mixin_fixture
+    ):
+        """test #9635"""
+        HasFoob, Order, Item = multi_mixin_fixture
+
+        stmt = (
+            insert(Order)
+            .from_select(
+                ["id", "description"],
+                select(Order.id, Order.description).where(Order.id > 8),
+            )
+            .options(
+                with_loader_criteria(
+                    HasFoob,
+                    lambda cls: cls.description != "name",
+                    include_aliases=True,
+                )
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "INSERT INTO orders (id, description) SELECT orders.id, "
+            "orders.description FROM orders WHERE orders.id > :id_1 "
+            "AND orders.description != :description_1",
+            checkparams={"description_1": "name", "id_1": 8},
+        )
+
+    @testing.variation("update_is_orm", [True, False])
+    def test_select_mapper_columns_w_core_cte_update_mapper_criteria(
+        self, multi_mixin_fixture, update_is_orm
+    ):
+        """test #9635"""
+        HasFoob, Order, Item = multi_mixin_fixture
+
+        cte = select(Order).cte("pd")
+
+        if update_is_orm:
+            stmt = (
+                update(Order)
+                .where(Order.id == cte.c.id)
+                .values(description="newname")
+            )
+        else:
+            stmt = (
+                update(Order.__table__)
+                .where(Order.__table__.c.id == cte.c.id)
+                .values(description="newname")
+            )
+
+        stmt = stmt.options(
+            with_loader_criteria(
+                HasFoob,
+                lambda cls: cls.description != "name",
+                include_aliases=True,
+            )
+        )
+
+        if update_is_orm:
+            self.assert_compile(
+                stmt,
+                "WITH pd AS (SELECT orders.id AS id, "
+                "orders.user_id AS user_id, "
+                "orders.address_id AS address_id, "
+                "orders.description AS description, orders.isopen AS isopen "
+                "FROM orders WHERE orders.description != %(description_1)s) "
+                "UPDATE orders SET description=%(description)s "
+                "FROM pd WHERE orders.id = pd.id "
+                "AND orders.description != %(description_2)s",
+                dialect="postgresql",
+                checkparams={
+                    "description": "newname",
+                    "description_1": "name",
+                    "description_2": "name",
+                },
+            )
+        else:
+            # non ORM update, no criteria, but criteria still gets rendered
+            # inside the SELECT
+            self.assert_compile(
+                stmt,
+                "WITH pd AS (SELECT orders.id AS id, "
+                "orders.user_id AS user_id, "
+                "orders.address_id AS address_id, "
+                "orders.description AS description, orders.isopen AS isopen "
+                "FROM orders WHERE orders.description != %(description_1)s) "
+                "UPDATE orders SET description=%(description)s "
+                "FROM pd WHERE orders.id = pd.id",
+                dialect="postgresql",
+                checkparams={
+                    "description": "newname",
+                    "description_1": "name",
+                },
+            )
+
+    @testing.variation("delete_is_orm", [True, False])
+    def test_select_mapper_columns_w_core_cte_delete_mapper_criteria(
+        self, multi_mixin_fixture, delete_is_orm
+    ):
+        """test #9635"""
+        HasFoob, Order, Item = multi_mixin_fixture
+
+        cte = select(Order).cte("pd")
+
+        if delete_is_orm:
+            stmt = delete(Order).where(Order.id == cte.c.id)
+        else:
+            stmt = delete(Order.__table__).where(
+                Order.__table__.c.id == cte.c.id
+            )
+
+        stmt = stmt.options(
+            with_loader_criteria(
+                HasFoob,
+                lambda cls: cls.description != "name",
+                include_aliases=True,
+            )
+        )
+
+        if delete_is_orm:
+            self.assert_compile(
+                stmt,
+                "WITH pd AS (SELECT orders.id AS id, orders.user_id AS "
+                "user_id, orders.address_id AS address_id, "
+                "orders.description AS description, orders.isopen AS isopen "
+                "FROM orders WHERE orders.description != %(description_1)s) "
+                "DELETE FROM orders USING pd WHERE orders.id = pd.id "
+                "AND orders.description != %(description_2)s",
+                dialect="postgresql",
+                checkparams={"description_1": "name", "description_2": "name"},
+            )
+        else:
+            # non ORM update, no criteria, but criteria still gets rendered
+            # inside the SELECT
+            self.assert_compile(
+                stmt,
+                "WITH pd AS (SELECT orders.id AS id, orders.user_id AS "
+                "user_id, orders.address_id AS address_id, "
+                "orders.description AS description, orders.isopen AS isopen "
+                "FROM orders WHERE orders.description != %(description_1)s) "
+                "DELETE FROM orders USING pd WHERE orders.id = pd.id",
+                dialect="postgresql",
+                checkparams={"description_1": "name"},
+            )
+
     def test_select_join_mapper_mapper_criteria(self, user_address_fixture):
         User, Address = user_address_fixture