]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Run SelectState from obj normalize ahead of calcing ORM joins
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 May 2021 15:20:10 +0000 (11:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 17 May 2021 16:26:48 +0000 (12:26 -0400)
Fixed regression where the full combination of joined inheritance, global
with_polymorphic, self-referential relationship and joined loading would
fail to be able to produce a query with the scope of lazy loads and object
refresh operations that also attempted to render the joined loader.

Fixes: #6495
Change-Id: If74a744c237069e3a89617498096c18b9b6e5dde

doc/build/changelog/unreleased_14/6495.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/selectable.py
test/orm/inheritance/test_assorted_poly.py

diff --git a/doc/build/changelog/unreleased_14/6495.rst b/doc/build/changelog/unreleased_14/6495.rst
new file mode 100644 (file)
index 0000000..8bb96bc
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 6495
+
+    Fixed regression where the full combination of joined inheritance, global
+    with_polymorphic, self-referential relationship and joined loading would
+    fail to be able to produce a query with the scope of lazy loads and object
+    refresh operations that also attempted to render the joined loader.
index ea84805b4d99e063a8a9d9d1c98558f709b96802..baad288359163e7697a6e41610652134fdde63f7 100644 (file)
@@ -591,9 +591,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         self.create_eager_joins = []
         self._fallback_from_clauses = []
 
-        self.from_clauses = [
+        # normalize the FROM clauses early by themselves, as this makes
+        # it an easier job when we need to assemble a JOIN onto these,
+        # for select.join() as well as joinedload().   As of 1.4 there are now
+        # potentially more complex sets of FROM objects here as the use
+        # of lambda statements for lazyload, load_on_pk etc. uses more
+        # cloning of the select() construct.  See #6495
+        self.from_clauses = self._normalize_froms(
             info.selectable for info in select_statement._from_obj
-        ]
+        )
 
         # this is a fairly arbitrary break into a second method,
         # so it might be nicer to break up create_for_statement()
index 2c37ecaa0fbeb567f894bf262266bc00b3d5af4f..dca45730c5597d593f39f1823cdcfde389223c9b 100644 (file)
@@ -2226,6 +2226,7 @@ class JoinedLoader(AbstractRelationshipLoader):
             and not should_nest_selectable
             and compile_state.from_clauses
         ):
+
             indexes = sql_util.find_left_clause_that_matches_given(
                 compile_state.from_clauses, query_entity.selectable
             )
index 7007bb430e70c997d2e2b930973f9a8d6e5d7cdc..997c3588ec098b3947b294013709dd9f1ab8d3bc 100644 (file)
@@ -4209,38 +4209,57 @@ class SelectState(util.MemoizedSlots, CompileState):
         return go
 
     def _get_froms(self, statement):
+        return self._normalize_froms(
+            itertools.chain(
+                itertools.chain.from_iterable(
+                    [
+                        element._from_objects
+                        for element in statement._raw_columns
+                    ]
+                ),
+                itertools.chain.from_iterable(
+                    [
+                        element._from_objects
+                        for element in statement._where_criteria
+                    ]
+                ),
+                self.from_clauses,
+            ),
+            check_statement=statement,
+        )
+
+    def _normalize_froms(self, iterable_of_froms, check_statement=None):
+        """given an iterable of things to select FROM, reduce them to what
+        would actually render in the FROM clause of a SELECT.
+
+        This does the job of checking for JOINs, tables, etc. that are in fact
+        overlapping due to cloning, adaption, present in overlapping joins,
+        etc.
+
+        """
         seen = set()
         froms = []
 
-        for item in itertools.chain(
-            itertools.chain.from_iterable(
-                [element._from_objects for element in statement._raw_columns]
-            ),
-            itertools.chain.from_iterable(
-                [
-                    element._from_objects
-                    for element in statement._where_criteria
-                ]
-            ),
-            self.from_clauses,
-        ):
-            if item._is_subquery and item.element is statement:
+        for item in iterable_of_froms:
+            if item._is_subquery and item.element is check_statement:
                 raise exc.InvalidRequestError(
                     "select() construct refers to itself as a FROM"
                 )
+
             if not seen.intersection(item._cloned_set):
                 froms.append(item)
                 seen.update(item._cloned_set)
 
-        toremove = set(
-            itertools.chain.from_iterable(
-                [_expand_cloned(f._hide_froms) for f in froms]
+        if froms:
+            toremove = set(
+                itertools.chain.from_iterable(
+                    [_expand_cloned(f._hide_froms) for f in froms]
+                )
             )
-        )
-        if toremove:
-            # filter out to FROM clauses not in the list,
-            # using a list to maintain ordering
-            froms = [f for f in froms if f not in toremove]
+            if toremove:
+                # filter out to FROM clauses not in the list,
+                # using a list to maintain ordering
+                froms = [f for f in froms if f not in toremove]
 
         return froms
 
index 3cf9c983722fd68f6ce29d60755fe1288b473890..6606e9b0ce8245832d9f89df49c0d573c5e4e454 100644 (file)
@@ -21,12 +21,14 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import polymorphic_union
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.orm.interfaces import MANYTOONE
 from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -1105,6 +1107,80 @@ class RelationshipTest8(fixtures.MappedTest):
         )
 
 
+class SelfRefWPolyJoinedLoadTest(fixtures.DeclarativeMappedTest):
+    """test #6495"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class Node(ComparableEntity, Base):
+            __tablename__ = "nodes"
+
+            id = Column(Integer, primary_key=True)
+
+            parent_id = Column(ForeignKey("nodes.id"))
+            type = Column(String(50))
+
+            parent = relationship("Node", remote_side=id)
+
+            local_groups = relationship("LocalGroup", lazy="joined")
+
+            __mapper_args__ = {
+                "polymorphic_on": type,
+                "with_polymorphic": ("*"),
+                "polymorphic_identity": "node",
+            }
+
+        class Content(Node):
+            __tablename__ = "content"
+
+            id = Column(ForeignKey("nodes.id"), primary_key=True)
+
+            __mapper_args__ = {
+                "polymorphic_identity": "content",
+            }
+
+        class File(Node):
+            __tablename__ = "file"
+
+            id = Column(ForeignKey("nodes.id"), primary_key=True)
+            __mapper_args__ = {
+                "polymorphic_identity": "file",
+            }
+
+        class LocalGroup(ComparableEntity, Base):
+            __tablename__ = "local_group"
+            id = Column(Integer, primary_key=True)
+
+            node_id = Column(ForeignKey("nodes.id"))
+
+    @classmethod
+    def insert_data(cls, connection):
+        Node, LocalGroup = cls.classes("Node", "LocalGroup")
+
+        with Session(connection) as sess:
+            f1 = Node(id=2, local_groups=[LocalGroup(), LocalGroup()])
+            c1 = Node(id=1)
+            c1.parent = f1
+
+            sess.add_all([f1, c1])
+
+            sess.commit()
+
+    def test_emit_lazy_loadonpk_parent(self):
+        Node, LocalGroup = self.classes("Node", "LocalGroup")
+
+        s = fixture_session()
+        c1 = s.query(Node).filter_by(id=1).first()
+
+        def go():
+            p1 = c1.parent
+            eq_(p1, Node(id=2, local_groups=[LocalGroup(), LocalGroup()]))
+
+        self.assert_sql_count(testing.db, go, 1)
+
+
 class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults):
     @classmethod
     def define_tables(cls, metadata):