]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improve natural_path usage in two places
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Apr 2023 20:48:25 +0000 (16:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 29 Apr 2023 05:07:53 +0000 (01:07 -0400)
Fixed loader strategy pathing issues where eager loaders such as
:func:`_orm.joinedload` / :func:`_orm.selectinload` would fail to traverse
fully for many-levels deep following a load that had a
:func:`_orm.with_polymorphic` or similar construct as an interim member.

Here we can take advantage of 2.0's refactoring of strategy_options
to identify the "chop_path" concept can be simplified to work
with "natural" paths alone.

In addition, identified existing
logic in PropRegistry that works fine, but needed the "is_unnatural"
attribute to be more accurate for a given path, so we set that
up front to True if the ancestor is_unnatural.

Fixes: #9715
Change-Id: Ie6b3f55b6a23d0d32628afd22437094263745114

doc/build/changelog/unreleased_20/9715.rst [new file with mode: 0644]
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/testing/fixtures.py
test/orm/inheritance/test_assorted_poly.py

diff --git a/doc/build/changelog/unreleased_20/9715.rst b/doc/build/changelog/unreleased_20/9715.rst
new file mode 100644 (file)
index 0000000..107051b
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9715
+
+    Fixed loader strategy pathing issues where eager loaders such as
+    :func:`_orm.joinedload` / :func:`_orm.selectinload` would fail to traverse
+    fully for many-levels deep following a load that had a
+    :func:`_orm.with_polymorphic` or similar construct as an interim member.
index e7084fbf676442c937f30af1e58dc3489ade423c..e1cbd9313ccc69cd7bcb37a13a5a5bec05cb84c3 100644 (file)
@@ -119,6 +119,8 @@ class PathRegistry(HasCacheKey):
     is_property = False
     is_entity = False
 
+    is_unnatural: bool
+
     path: _PathRepresentation
     natural_path: _PathRepresentation
     parent: Optional[PathRegistry]
@@ -510,7 +512,12 @@ class PropRegistry(PathRegistry):
         # given MapperProperty's parent.
         insp = cast("_InternalEntityType[Any]", parent[-1])
         natural_parent: AbstractEntityRegistry = parent
-        self.is_unnatural = False
+
+        # inherit "is_unnatural" from the parent
+        if parent.parent.is_unnatural:
+            self.is_unnatural = True
+        else:
+            self.is_unnatural = False
 
         if not insp.is_aliased_class or insp._use_mapper_path:  # type: ignore
             parent = natural_parent = parent.parent[prop.parent]
@@ -570,6 +577,7 @@ class PropRegistry(PathRegistry):
         self.parent = parent
         self.path = parent.path + (prop,)
         self.natural_path = natural_parent.natural_path + (prop,)
+
         self.has_entity = prop._links_to_entity
         if prop._is_relationship:
             if TYPE_CHECKING:
@@ -674,7 +682,6 @@ class AbstractEntityRegistry(CreatesToken):
         # elif not parent.path and self.is_aliased_class:
         #     self.natural_path = (self.entity._generate_cache_key()[0], )
         else:
-            # self.natural_path = parent.natural_path + (entity, )
             self.natural_path = self.path
 
     def _truncate_recursive(self) -> AbstractEntityRegistry:
index 5581e5c7fa55daa8a4051f27ee510eb9bc5ceb77..8e06c4f5986bcc3c1af90c6b1d35cb3d7fff173e 100644 (file)
@@ -1063,6 +1063,7 @@ class LazyLoader(
 
         if extra_options:
             stmt._with_options += extra_options
+
         stmt._compile_options += {"_current_path": effective_path}
 
         if use_get:
index 48e69aef27e6a1b817636c4883dc15565fa25425..2e073f326c8256cc5fc79e4543c7a89e8f8f8b64 100644 (file)
@@ -927,7 +927,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
     ) -> Optional[_PathRepresentation]:
         i = -1
 
-        for i, (c_token, p_token) in enumerate(zip(to_chop, path.path)):
+        for i, (c_token, p_token) in enumerate(
+            zip(to_chop, path.natural_path)
+        ):
             if isinstance(c_token, str):
                 if i == 0 and c_token.endswith(f":{_DEFAULT_TOKEN}"):
                     return to_chop
@@ -942,36 +944,8 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             elif (
                 isinstance(c_token, InspectionAttr)
                 and insp_is_mapper(c_token)
-                and (
-                    (insp_is_mapper(p_token) and c_token.isa(p_token))
-                    or (
-                        # a too-liberal check here to allow a path like
-                        # A->A.bs->B->B.cs->C->C.ds, natural path, to chop
-                        # against current path
-                        # A->A.bs->B(B, B2)->B(B, B2)->cs, in an of_type()
-                        # scenario which should only be occurring in a loader
-                        # that is against a non-aliased lead element with
-                        # single path.  otherwise the
-                        # "B" won't match into the B(B, B2).
-                        #
-                        # i>=2 prevents this check from proceeding for
-                        # the first path element.
-                        #
-                        # if we could do away with the "natural_path"
-                        # concept, we would not need guessy checks like this
-                        #
-                        # two conflicting tests for this comparison are:
-                        # test_eager_relations.py->
-                        #       test_lazyload_aliased_abs_bcs_two
-                        # and
-                        # test_of_type.py->test_all_subq_query
-                        #
-                        i >= 2
-                        and insp_is_aliased_class(p_token)
-                        and p_token._is_with_polymorphic
-                        and c_token in p_token.with_polymorphic_mappers
-                    )
-                )
+                and insp_is_mapper(p_token)
+                and c_token.isa(p_token)
             ):
                 continue
 
@@ -1321,7 +1295,7 @@ class _WildcardLoad(_AbstractLoad):
 
     strategy: Optional[Tuple[Any, ...]]
     local_opts: _OptsType
-    path: Tuple[str, ...]
+    path: Union[Tuple[()], Tuple[str]]
     propagate_to_loaders = False
 
     def __init__(self) -> None:
@@ -1366,6 +1340,7 @@ class _WildcardLoad(_AbstractLoad):
         it may be used as the sub-option of a :class:`_orm.Load` object.
 
         """
+        assert self.path
         attr = self.path[0]
         if attr.endswith(_DEFAULT_TOKEN):
             attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}"
@@ -1396,13 +1371,16 @@ class _WildcardLoad(_AbstractLoad):
 
         start_path: _PathRepresentation = self.path
 
-        # TODO: chop_path already occurs in loader.process_compile_state()
-        # so we will seek to simplify this
         if current_path:
+            # TODO: no cases in test suite where we actually get
+            # None back here
             new_path = self._chop_path(start_path, current_path)
-            if not new_path:
+            if new_path is None:
                 return
-            start_path = new_path
+
+            # chop_path does not actually "chop" a wildcard token path,
+            # just returns it
+            assert new_path == start_path
 
         # start_path is a single-token tuple
         assert start_path and len(start_path) == 1
@@ -1618,7 +1596,9 @@ class _LoadElement(
 
 
         """
-        chopped_start_path = Load._chop_path(effective_path.path, current_path)
+        chopped_start_path = Load._chop_path(
+            effective_path.natural_path, current_path
+        )
         if not chopped_start_path:
             return None
 
index fc1fa1483c63128243b16f545a99ea33da00a6f1..bff251b0f747c45a340bbaecdaffa6e523eb97d2 100644 (file)
@@ -13,6 +13,7 @@ import itertools
 import random
 import re
 import sys
+from typing import Any
 
 import sqlalchemy as sa
 from . import assertions
@@ -675,7 +676,7 @@ class MappedTest(TablesTest, assertions.AssertsExecutionResults):
     # 'once', 'each', None
     run_setup_mappers = "each"
 
-    classes = None
+    classes: Any = None
 
     @config.fixture(autouse=True, scope="class")
     def _setup_tables_test_class(self):
index 4bebc9b102818edbd46354fdcf6b0e402281454e..a40a9ae742c6b438f8f92e3469871033975b4396 100644 (file)
@@ -3,6 +3,10 @@ These are generally tests derived from specific user issues.
 
 """
 
+from __future__ import annotations
+
+from typing import Optional
+
 from sqlalchemy import exists
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
@@ -17,10 +21,14 @@ from sqlalchemy.orm import aliased
 from sqlalchemy.orm import class_mapper
 from sqlalchemy.orm import column_property
 from sqlalchemy.orm import contains_eager
+from sqlalchemy.orm import immediateload
 from sqlalchemy.orm import join
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import polymorphic_union
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import with_polymorphic
@@ -2554,3 +2562,206 @@ class Issue8168Test(AssertsCompiledSQL, fixtures.TestBase):
             )
         else:
             scenario.fail()
+
+
+class PolyIntoSelfReferentialTest(
+    fixtures.DeclarativeMappedTest, AssertsExecutionResults
+):
+    """test for #9715"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=True
+            )
+
+            rel_id: Mapped[int] = mapped_column(ForeignKey("related.id"))
+
+            related = relationship("Related")
+
+        class Related(Base):
+            __tablename__ = "related"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=True
+            )
+            rel_data: Mapped[str]
+            type: Mapped[str] = mapped_column()
+
+            other_related_id: Mapped[int] = mapped_column(
+                ForeignKey("other_related.id")
+            )
+
+            other_related = relationship("OtherRelated")
+
+            __mapper_args__ = {
+                "polymorphic_identity": "related",
+                "polymorphic_on": type,
+            }
+
+        class SubRelated(Related):
+            __tablename__ = "sub_related"
+
+            id: Mapped[int] = mapped_column(
+                ForeignKey("related.id"), primary_key=True
+            )
+            sub_rel_data: Mapped[str]
+
+            __mapper_args__ = {"polymorphic_identity": "sub_related"}
+
+        class OtherRelated(Base):
+            __tablename__ = "other_related"
+
+            id: Mapped[int] = mapped_column(
+                primary_key=True, autoincrement=True
+            )
+            name: Mapped[str]
+
+            parent_id: Mapped[Optional[int]] = mapped_column(
+                ForeignKey("other_related.id")
+            )
+            parent = relationship("OtherRelated", lazy="raise", remote_side=id)
+
+    @classmethod
+    def insert_data(cls, connection):
+        A, SubRelated, OtherRelated = cls.classes(
+            "A", "SubRelated", "OtherRelated"
+        )
+
+        with Session(connection) as sess:
+
+            grandparent_otherrel1 = OtherRelated(name="GP1")
+            grandparent_otherrel2 = OtherRelated(name="GP2")
+
+            parent_otherrel1 = OtherRelated(
+                name="P1", parent=grandparent_otherrel1
+            )
+            parent_otherrel2 = OtherRelated(
+                name="P2", parent=grandparent_otherrel2
+            )
+
+            otherrel1 = OtherRelated(name="A1", parent=parent_otherrel1)
+            otherrel3 = OtherRelated(name="A2", parent=parent_otherrel2)
+
+            address1 = SubRelated(
+                rel_data="ST1", other_related=otherrel1, sub_rel_data="w1"
+            )
+            address3 = SubRelated(
+                rel_data="ST2", other_related=otherrel3, sub_rel_data="w2"
+            )
+
+            a1 = A(related=address1)
+            a2 = A(related=address3)
+
+            sess.add_all([a1, a2])
+            sess.commit()
+
+    def _run_load(self, *opt):
+        A = self.classes.A
+        stmt = select(A).options(*opt)
+
+        sess = fixture_session()
+        all_a = sess.scalars(stmt).all()
+
+        sess.close()
+
+        with self.assert_statement_count(testing.db, 0):
+            for a1 in all_a:
+                d1 = a1.related
+                d2 = d1.other_related
+                d3 = d2.parent
+                d4 = d3.parent
+                assert d4.name in ("GP1", "GP2")
+
+    @testing.variation("use_workaround", [True, False])
+    def test_workaround(self, use_workaround):
+        A, Related, SubRelated, OtherRelated = self.classes(
+            "A", "Related", "SubRelated", "OtherRelated"
+        )
+
+        related = with_polymorphic(Related, [SubRelated], flat=True)
+
+        opt = [
+            (
+                joinedload(A.related.of_type(related))
+                .joinedload(related.other_related)
+                .joinedload(
+                    OtherRelated.parent,
+                )
+            )
+        ]
+        if use_workaround:
+            opt.append(
+                joinedload(
+                    A.related,
+                    Related.other_related,
+                    OtherRelated.parent,
+                    OtherRelated.parent,
+                )
+            )
+        else:
+            opt[0] = opt[0].joinedload(OtherRelated.parent)
+
+        self._run_load(*opt)
+
+    @testing.combinations(
+        (("joined", "joined", "joined", "joined"),),
+        (("selectin", "selectin", "selectin", "selectin"),),
+        (("selectin", "selectin", "joined", "joined"),),
+        (("selectin", "selectin", "joined", "selectin"),),
+        (("joined", "selectin", "joined", "selectin"),),
+        # TODO: immediateload (and lazyload) do not support the target item
+        # being a with_polymorphic.  this seems to be a limitation in the
+        # current_path logic
+        # (("immediate", "joined", "joined", "joined"),),
+        argnames="loaders",
+    )
+    @testing.variation("use_wpoly", [True, False])
+    def test_all_load(self, loaders, use_wpoly):
+        A, Related, SubRelated, OtherRelated = self.classes(
+            "A", "Related", "SubRelated", "OtherRelated"
+        )
+
+        if use_wpoly:
+            related = with_polymorphic(Related, [SubRelated], flat=True)
+        else:
+            related = SubRelated
+
+        opt = None
+        for i, (load_type, element) in enumerate(
+            zip(
+                loaders,
+                [
+                    A.related.of_type(related),
+                    related.other_related,
+                    OtherRelated.parent,
+                    OtherRelated.parent,
+                ],
+            )
+        ):
+            if i == 0:
+                if load_type == "joined":
+                    opt = joinedload(element)
+                elif load_type == "selectin":
+                    opt = selectinload(element)
+                elif load_type == "immediate":
+                    opt = immediateload(element)
+                else:
+                    assert False
+            else:
+                assert opt is not None
+                if load_type == "joined":
+                    opt = opt.joinedload(element)
+                elif load_type == "selectin":
+                    opt = opt.selectinload(element)
+                elif load_type == "immediate":
+                    opt = opt.immediateload(element)
+                else:
+                    assert False
+
+        self._run_load(opt)