]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
remove upfront sanitization of entities from joins
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Sep 2025 16:26:06 +0000 (12:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Sep 2025 13:10:03 +0000 (09:10 -0400)
ORM entities can now be involved within the SQL expressions used within
:paramref:`_orm.relationship.primaryjoin` and
:paramref:`_orm.relationship.secondaryjoin` parameters without the ORM
entity information being implicitly sanitized, allowing ORM-specific
features such as single-inheritance criteria in subqueries to continue
working even when used in this context.   This is made possible by overall
ORM simplifications that occurred as of the 2.0 series.  The changes here
also provide a performance boost (up to 20%) for certain query compilation
scenarios.

Here we see that we're not only able to remove the
relationships deannotation steps, but we can also change
context -> _get_current_adapter() to be an unconditional
adapter, since the only remaining case where it was conditional
was the polymorphic_adapter.  that adapter is itself
only used for exotic joined inh cases against select
statements (totally not used by anyone) or by abstract
concrete setups.   That lets us remove a whole host
of orm_annotate stuff that doesn't apply anymore.

if this does lead to user regressions in 2.1 it will be
a good reason for us to revisit the complexity here in
any case.

Fixes: #12843
Change-Id: Ic1c6e72d70ec6a27b73495c1a56e9307c9280133

doc/build/changelog/unreleased_21/12843.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/inheritance/test_relationship.py
test/orm/test_relationships.py
test/profiles.txt

diff --git a/doc/build/changelog/unreleased_21/12843.rst b/doc/build/changelog/unreleased_21/12843.rst
new file mode 100644 (file)
index 0000000..679edf0
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12843
+
+    ORM entities can now be involved within the SQL expressions used within
+    :paramref:`_orm.relationship.primaryjoin` and
+    :paramref:`_orm.relationship.secondaryjoin` parameters without the ORM
+    entity information being implicitly sanitized, allowing ORM-specific
+    features such as single-inheritance criteria in subqueries to continue
+    working even when used in this context.   This is made possible by overall
+    ORM simplifications that occurred as of the 2.0 series.  The changes here
+    also provide a performance boost (up to 20%) for certain query compilation
+    scenarios.
index f00691fbc8956bcec2faa0b531018eaef24a47c9..15a9d2a7869906a687fee182b65b7540e7931d12 100644 (file)
@@ -1828,17 +1828,14 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
             # subquery of itself, i.e. _from_selectable(), apply adaption
             # to all SQL constructs.
             adapters.append(
-                (
-                    True,
-                    self._from_obj_alias.replace,
-                )
+                self._from_obj_alias.replace,
             )
 
         # this was *hopefully* the only adapter we were going to need
         # going forward...however, we unfortunately need _from_obj_alias
         # for query.union(), which we can't drop
         if self._polymorphic_adapters:
-            adapters.append((False, self._adapt_polymorphic_element))
+            adapters.append(self._adapt_polymorphic_element)
 
         if not adapters:
             return None
@@ -1848,15 +1845,10 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
             # tagged as 'ORM' constructs ?
 
             def replace(elem):
-                is_orm_adapt = (
-                    "_orm_adapt" in elem._annotations
-                    or "parententity" in elem._annotations
-                )
-                for always_adapt, adapter in adapters:
-                    if is_orm_adapt or always_adapt:
-                        e = adapter(elem)
-                        if e is not None:
-                            return e
+                for adapter in adapters:
+                    e = adapter(elem)
+                    if e is not None:
+                        return e
 
             return visitors.replacement_traverse(clause, {}, replace)
 
@@ -2565,7 +2557,6 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
             # finally run all the criteria through the "main" adapter, if we
             # have one, and concatenate to final WHERE criteria
             for crit in _where_criteria_to_add:
-                crit = sql_util._deep_annotate(crit, {"_orm_adapt": True})
                 crit = current_adapter(crit, False)
                 self._where_criteria += (crit,)
         else:
index 0e28a7a4682d0d9e06cd23a5120ede23a98dc136..f1d90f8d8728c5772d3510498bc4c01314b3b805 100644 (file)
@@ -581,9 +581,7 @@ def _load_on_pk_identity(
                     "release."
                 )
 
-        q._where_criteria = (
-            sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}),
-        )
+        q._where_criteria = (_get_clause,)
 
         params = {
             _get_params[primary_key].key: id_val
index 33d6646b35a91a24b00a788d5fd6727560466fd2..017b829d8a0971f6b8ea9029a5d8c7497a03ec0c 100644 (file)
@@ -3821,10 +3821,7 @@ class Mapper(
                     _reconcile_to_other=False,
                 )
 
-        primary_key = [
-            sql_util._deep_annotate(pk, {"_orm_adapt": True})
-            for pk in self.primary_key
-        ]
+        primary_key = list(self.primary_key)
 
         in_expr: ColumnElement[Any]
 
index b608c520160868bb6853bef5cdeb708a6cff8f69..2373fd9ccdebbdd25da1a29ecda5742c577550c5 100644 (file)
@@ -40,6 +40,7 @@ from typing import Sequence
 from typing import Set
 from typing import Tuple
 from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 import weakref
@@ -65,8 +66,6 @@ from .interfaces import ONETOMANY
 from .interfaces import PropComparator
 from .interfaces import RelationshipDirection
 from .interfaces import StrategizedProperty
-from .util import _orm_annotate
-from .util import _orm_deannotate
 from .util import CascadeOptions
 from .. import exc as sa_exc
 from .. import Exists
@@ -394,7 +393,7 @@ class RelationshipProperty(
     synchronize_pairs: _ColumnPairs
     secondary_synchronize_pairs: Optional[_ColumnPairs]
 
-    local_remote_pairs: Optional[_ColumnPairs]
+    local_remote_pairs: _ColumnPairs
 
     direction: RelationshipDirection
 
@@ -515,7 +514,7 @@ class RelationshipProperty(
             )
 
         self.omit_join = omit_join
-        self.local_remote_pairs = _local_remote_pairs
+        self.local_remote_pairs = _local_remote_pairs or ()
         self.load_on_pending = load_on_pending
         self.comparator_factory = (
             comparator_factory or RelationshipProperty.Comparator
@@ -804,10 +803,8 @@ class RelationshipProperty(
                 if self.property.direction in [ONETOMANY, MANYTOMANY]:
                     return ~self._criterion_exists()
                 else:
-                    return _orm_annotate(
-                        self.property._optimized_compare(
-                            None, adapt_source=self.adapter
-                        )
+                    return self.property._optimized_compare(
+                        None, adapt_source=self.adapter
                     )
             elif self.property.uselist:
                 raise sa_exc.InvalidRequestError(
@@ -815,10 +812,8 @@ class RelationshipProperty(
                     "use contains() to test for membership."
                 )
             else:
-                return _orm_annotate(
-                    self.property._optimized_compare(
-                        other, adapt_source=self.adapter
-                    )
+                return self.property._optimized_compare(
+                    other, adapt_source=self.adapter
                 )
 
         def _criterion_exists(
@@ -882,10 +877,11 @@ class RelationshipProperty(
             # annotate the *local* side of the join condition, in the case
             # of pj + sj this is the full primaryjoin, in the case of just
             # pj its the local side of the primaryjoin.
+            j: ColumnElement[bool]
             if sj is not None:
-                j = _orm_annotate(pj) & sj
+                j = pj & sj
             else:
-                j = _orm_annotate(pj, exclude=self.property.remote_side)
+                j = pj
 
             if (
                 where_criteria is not None
@@ -1194,10 +1190,8 @@ class RelationshipProperty(
             """
             if other is None or isinstance(other, expression.Null):
                 if self.property.direction == MANYTOONE:
-                    return _orm_annotate(
-                        ~self.property._optimized_compare(
-                            None, adapt_source=self.adapter
-                        )
+                    return ~self.property._optimized_compare(
+                        None, adapt_source=self.adapter
                     )
 
                 else:
@@ -1209,7 +1203,10 @@ class RelationshipProperty(
                     "contains() to test for membership."
                 )
             else:
-                return _orm_annotate(self.__negated_contains_or_equals(other))
+                return self.__negated_contains_or_equals(other)
+
+        if TYPE_CHECKING:
+            property: RelationshipProperty[_PT]  # noqa: A001
 
         def _memoized_attr_property(self) -> RelationshipProperty[_PT]:
             self.prop.parent._check_configure()
@@ -1757,10 +1754,8 @@ class RelationshipProperty(
             rel_arg = getattr(init_args, attr)
             val = rel_arg.resolved
             if val is not None:
-                rel_arg.resolved = _orm_deannotate(
-                    coercions.expect(
-                        roles.ColumnArgumentRole, val, argname=attr
-                    )
+                rel_arg.resolved = coercions.expect(
+                    roles.ColumnArgumentRole, val, argname=attr
                 )
 
         secondary = init_args.secondary.resolved
@@ -2393,7 +2388,6 @@ class _JoinCondition:
         self._determine_joins()
         assert self.primaryjoin is not None
 
-        self._sanitize_joins()
         self._annotate_fks()
         self._annotate_remote()
         self._annotate_local()
@@ -2444,24 +2438,6 @@ class _JoinCondition:
         )
         log.info("%s relationship direction %s", self.prop, self.direction)
 
-    def _sanitize_joins(self) -> None:
-        """remove the parententity annotation from our join conditions which
-        can leak in here based on some declarative patterns and maybe others.
-
-        "parentmapper" is relied upon both by the ORM evaluator as well as
-        the use case in _join_fixture_inh_selfref_w_entity
-        that relies upon it being present, see :ticket:`3364`.
-
-        """
-
-        self.primaryjoin = _deep_deannotate(
-            self.primaryjoin, values=("parententity", "proxy_key")
-        )
-        if self.secondaryjoin is not None:
-            self.secondaryjoin = _deep_deannotate(
-                self.secondaryjoin, values=("parententity", "proxy_key")
-            )
-
     def _determine_joins(self) -> None:
         """Determine the 'primaryjoin' and 'secondaryjoin' attributes,
         if not passed to the constructor already.
index 392ba671e95774285ca54273bb2f895a7c0625a7..8af9c020382f7c8eec418e64837654e606874f95 100644 (file)
@@ -740,10 +740,7 @@ class _LazyLoader(
         ) = join_condition.create_lazy_clause(reverse_direction=True)
 
         if self.parent_property.order_by:
-            self._order_by = [
-                sql_util._deep_annotate(elem, {"_orm_adapt": True})
-                for elem in util.to_list(self.parent_property.order_by)
-            ]
+            self._order_by = util.to_list(self.parent_property.order_by)
         else:
             self._order_by = None
 
@@ -812,9 +809,7 @@ class _LazyLoader(
         )
 
     def _memoized_attr__simple_lazy_clause(self):
-        lazywhere = sql_util._deep_annotate(
-            self._lazywhere, {"_orm_adapt": True}
-        )
+        lazywhere = self._lazywhere
 
         criterion, bind_to_col = (lazywhere, self._bind_to_col)
 
index 0f43aff820f654e467d2bcd35775fdcdcc37867e..aacf2f736237da66b0c812e0e7103fbe4002eeb3 100644 (file)
@@ -1755,30 +1755,6 @@ class Bundle(
         return proc
 
 
-def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA:
-    """Deep copy the given ClauseElement, annotating each element with the
-    "_orm_adapt" flag.
-
-    Elements within the exclude collection will be cloned but not annotated.
-
-    """
-    return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude)
-
-
-def _orm_deannotate(element: _SA) -> _SA:
-    """Remove annotations that link a column to a particular mapping.
-
-    Note this doesn't affect "remote" and "foreign" annotations
-    passed by the :func:`_orm.foreign` and :func:`_orm.remote`
-    annotators.
-
-    """
-
-    return sql_util._deep_deannotate(
-        element, values=("_orm_adapt", "parententity")
-    )
-
-
 def _orm_full_deannotate(element: _SA) -> _SA:
     return sql_util._deep_deannotate(element)
 
index e2016f8b5d9b07becd363fba66e139cc33a49478..49f8c9062086e0e05394742d4a7d251c62fffdb3 100644 (file)
@@ -1,17 +1,23 @@
 from contextlib import nullcontext
 
+from sqlalchemy import and_
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import column_property
 from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import contains_eager
+from sqlalchemy.orm import foreign
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
@@ -23,6 +29,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing.assertions import expect_raises_message
+from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -3103,3 +3110,115 @@ class JoinedLoadSpliceFromJoinedTest(
             "ON base_model_1.id = sub_model_element_1.model_id"
             "",
         )
+
+
+class SingleSubclassInRelationship(
+    AssertsCompiledSQL, fixtures.DeclarativeMappedTest
+):
+    """test for #12843 / discussion #12842"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class LogEntry(ComparableEntity, Base):
+            __tablename__ = "log_entry"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            timestamp: Mapped[int] = mapped_column(Integer)
+            type: Mapped[str]
+
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "log_entry",
+            }
+
+        class StartEntry(LogEntry):
+            __mapper_args__ = {
+                "polymorphic_identity": "start_entry",
+            }
+
+        StartAlias = aliased(StartEntry)
+
+        next_start_ts = (
+            select(func.min(StartAlias.timestamp))
+            .where(
+                StartAlias.timestamp > LogEntry.timestamp,
+            )
+            .scalar_subquery()
+        )
+
+        StartEntry.next_start_ts = column_property(next_start_ts)
+
+        LogAlias = aliased(LogEntry)
+
+        StartEntry.associated_entries = relationship(
+            LogAlias,
+            primaryjoin=and_(
+                foreign(LogAlias.timestamp) >= LogEntry.timestamp,
+                or_(
+                    next_start_ts == None,
+                    LogAlias.timestamp < next_start_ts,
+                ),
+            ),
+            viewonly=True,
+            order_by=LogAlias.id,
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        LogEntry, StartEntry = cls.classes.LogEntry, cls.classes.StartEntry
+
+        with Session(connection) as sess:
+            s1 = StartEntry(timestamp=1)
+            l1 = LogEntry(timestamp=2)
+            l2 = LogEntry(timestamp=3)
+
+            s2 = StartEntry(timestamp=4)
+            l3 = LogEntry(timestamp=5)
+
+            sess.add_all([s1, l1, l2, s2, l3])
+            sess.commit()
+
+    def test_assoc_entries(self):
+        LogEntry, StartEntry = self.classes.LogEntry, self.classes.StartEntry
+
+        sess = fixture_session()
+
+        s1 = sess.scalars(select(StartEntry).filter_by(timestamp=1)).one()
+
+        with self.sql_execution_asserter(testing.db) as asserter:
+            eq_(
+                s1.associated_entries,
+                [
+                    StartEntry(timestamp=1),
+                    LogEntry(timestamp=2),
+                    LogEntry(timestamp=3),
+                ],
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                "SELECT log_entry_1.id AS log_entry_1_id, "
+                "log_entry_1.timestamp AS log_entry_1_timestamp, "
+                "log_entry_1.type AS log_entry_1_type "
+                "FROM log_entry AS log_entry_1 "
+                "WHERE log_entry_1.timestamp >= :param_1 AND "
+                "((SELECT min(log_entry_2.timestamp) AS min_1 "
+                "FROM log_entry AS log_entry_2 "
+                "WHERE log_entry_2.timestamp > :param_1 "
+                "AND log_entry_2.type IN (__[POSTCOMPILE_type_1])) IS NULL "
+                "OR log_entry_1.timestamp < "
+                "(SELECT min(log_entry_2.timestamp) AS min_1 "
+                "FROM log_entry AS log_entry_2 "
+                "WHERE log_entry_2.timestamp > :param_1 "
+                "AND log_entry_2.type IN (__[POSTCOMPILE_type_2]))) "
+                "ORDER BY log_entry_1.id",
+                params=[
+                    {
+                        "param_1": 1,
+                        "type_1": ["start_entry"],
+                        "type_2": ["start_entry"],
+                    }
+                ],
+            )
+        )
index 589dcf2fed46b3eae7e3b9e0a9be915d314d1686..3bfb9b06bfa17643941ae5756101d40f6209225d 100644 (file)
@@ -6682,3 +6682,92 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest):
                 params=[{"id_1": "%", "param_1": "%", "primary_keys": [2]}],
             ),
         )
+
+
+class AnnotationsMaintainedTest(AssertsCompiledSQL, fixtures.TestBase):
+    """tests for #12843"""
+
+    __dialect__ = "default"
+
+    def test_annos_maintained(self, decl_base):
+        class User(decl_base):
+            __tablename__ = "user"
+            id = Column(Integer, primary_key=True)
+
+        class Address(decl_base):
+            __tablename__ = "address"
+            id = Column(Integer, primary_key=True)
+            user_id = Column(ForeignKey("user.id"))
+
+        User.addresses = relationship(
+            Address, primaryjoin=User.id == foreign(Address.user_id)
+        )
+
+        is_(
+            User.addresses.property.primaryjoin.left._annotations[
+                "parententity"
+            ],
+            User.__mapper__,
+        )
+        is_(
+            User.addresses.property.primaryjoin.right._annotations[
+                "parententity"
+            ],
+            Address.__mapper__,
+        )
+
+    @testing.variation("use_orm", [True, False])
+    def test_orm_operations_primaryjoin(self, decl_base, use_orm):
+        class Employee(decl_base):
+            __tablename__ = "employee"
+            id = Column(Integer, primary_key=True)
+            type = Column(String(50))
+            company_id = Column(Integer)
+            __mapper_args__ = {
+                "polymorphic_identity": "employee",
+                "polymorphic_on": type,
+            }
+
+        class Engineer(Employee):
+            __mapper_args__ = {"polymorphic_identity": "engineer"}
+
+        class Company(decl_base):
+            __tablename__ = "company"
+            id = Column(Integer, primary_key=True)
+
+            employees_who_are_engineers = relationship(
+                Employee,
+                # this is a ridiculous primaryjoin and relationship,
+                # but we just need to see that the single inh clause
+                # generates, indicating we know we have an ORM entity
+                # for Engineer
+                primaryjoin=lambda: and_(
+                    foreign(Employee.company_id) == Company.id,
+                    Employee.id.in_(subq),
+                ),
+            )
+
+        if use_orm:
+            # will render "type IN <types>"
+            subq = (
+                select(Engineer)
+                .where(foreign(Engineer.company_id) == Company.id)
+                .correlate(Company)
+            )
+        else:
+            # will not render "type IN <types>"
+            subq = (
+                select(Engineer.__table__)
+                .where(foreign(Engineer.company_id) == Company.id)
+                .correlate(Company)
+            )
+
+        self.assert_compile(
+            select(Company).join(Company.employees_who_are_engineers),
+            "SELECT company.id FROM company JOIN employee "
+            "ON employee.company_id = company.id AND employee.id IN "
+            "(SELECT employee.id, employee.type, employee.company_id "
+            "FROM employee WHERE employee.company_id = company.id"
+            f"""{" AND employee.type IN (__[POSTCOMPILE_type_1])"
+                 if use_orm else ""})""",
+        )
index d29b8c745491e91cd6a444ac5ec1b1768596b395..512c700249cb82875f181d7d3174a52f82349858 100644 (file)
@@ -373,10 +373,10 @@ test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_
 
 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity
 
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_cextensions 110410
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_nocextensions 123169
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_cextensions 110657
-test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_nocextensions 123416
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_cextensions 84733
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_nocextensions 97492
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_cextensions 84980
+test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_no_identity x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_nocextensions 97739
 
 # TEST: test.aaa_profiling.test_orm.MergeBackrefsTest.test_merge_pending_with_all_pks