From: Mike Bayer Date: Mon, 17 May 2021 15:20:10 +0000 (-0400) Subject: Run SelectState from obj normalize ahead of calcing ORM joins X-Git-Tag: rel_1_4_16~24 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a49b2c3dbb9bff1d004eb2c53a752999e27ff769;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Run SelectState from obj normalize ahead of calcing ORM joins Fixed regression where the full combination of joined inheritance, global with_polymorphic, self-referential relationship and joined loading would fail to be able to produce a query with the scope of lazy loads and object refresh operations that also attempted to render the joined loader. Fixes: #6495 Change-Id: If74a744c237069e3a89617498096c18b9b6e5dde --- diff --git a/doc/build/changelog/unreleased_14/6495.rst b/doc/build/changelog/unreleased_14/6495.rst new file mode 100644 index 0000000000..8bb96bc423 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6495.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 6495 + + Fixed regression where the full combination of joined inheritance, global + with_polymorphic, self-referential relationship and joined loading would + fail to be able to produce a query with the scope of lazy loads and object + refresh operations that also attempted to render the joined loader. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index ea84805b4d..baad288359 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -591,9 +591,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState): self.create_eager_joins = [] self._fallback_from_clauses = [] - self.from_clauses = [ + # normalize the FROM clauses early by themselves, as this makes + # it an easier job when we need to assemble a JOIN onto these, + # for select.join() as well as joinedload(). As of 1.4 there are now + # potentially more complex sets of FROM objects here as the use + # of lambda statements for lazyload, load_on_pk etc. uses more + # cloning of the select() construct. See #6495 + self.from_clauses = self._normalize_froms( info.selectable for info in select_statement._from_obj - ] + ) # this is a fairly arbitrary break into a second method, # so it might be nicer to break up create_for_statement() diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 2c37ecaa0f..dca45730c5 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2226,6 +2226,7 @@ class JoinedLoader(AbstractRelationshipLoader): and not should_nest_selectable and compile_state.from_clauses ): + indexes = sql_util.find_left_clause_that_matches_given( compile_state.from_clauses, query_entity.selectable ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 7007bb430e..997c3588ec 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4209,38 +4209,57 @@ class SelectState(util.MemoizedSlots, CompileState): return go def _get_froms(self, statement): + return self._normalize_froms( + itertools.chain( + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._raw_columns + ] + ), + itertools.chain.from_iterable( + [ + element._from_objects + for element in statement._where_criteria + ] + ), + self.from_clauses, + ), + check_statement=statement, + ) + + def _normalize_froms(self, iterable_of_froms, check_statement=None): + """given an iterable of things to select FROM, reduce them to what + would actually render in the FROM clause of a SELECT. + + This does the job of checking for JOINs, tables, etc. that are in fact + overlapping due to cloning, adaption, present in overlapping joins, + etc. + + """ seen = set() froms = [] - for item in itertools.chain( - itertools.chain.from_iterable( - [element._from_objects for element in statement._raw_columns] - ), - itertools.chain.from_iterable( - [ - element._from_objects - for element in statement._where_criteria - ] - ), - self.from_clauses, - ): - if item._is_subquery and item.element is statement: + for item in iterable_of_froms: + if item._is_subquery and item.element is check_statement: raise exc.InvalidRequestError( "select() construct refers to itself as a FROM" ) + if not seen.intersection(item._cloned_set): froms.append(item) seen.update(item._cloned_set) - toremove = set( - itertools.chain.from_iterable( - [_expand_cloned(f._hide_froms) for f in froms] + if froms: + toremove = set( + itertools.chain.from_iterable( + [_expand_cloned(f._hide_froms) for f in froms] + ) ) - ) - if toremove: - # filter out to FROM clauses not in the list, - # using a list to maintain ordering - froms = [f for f in froms if f not in toremove] + if toremove: + # filter out to FROM clauses not in the list, + # using a list to maintain ordering + froms = [f for f in froms if f not in toremove] return froms diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 3cf9c98372..6606e9b0ce 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -21,12 +21,14 @@ from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import polymorphic_union from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import with_polymorphic from sqlalchemy.orm.interfaces import MANYTOONE from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -1105,6 +1107,80 @@ class RelationshipTest8(fixtures.MappedTest): ) +class SelfRefWPolyJoinedLoadTest(fixtures.DeclarativeMappedTest): + """test #6495""" + + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + class Node(ComparableEntity, Base): + __tablename__ = "nodes" + + id = Column(Integer, primary_key=True) + + parent_id = Column(ForeignKey("nodes.id")) + type = Column(String(50)) + + parent = relationship("Node", remote_side=id) + + local_groups = relationship("LocalGroup", lazy="joined") + + __mapper_args__ = { + "polymorphic_on": type, + "with_polymorphic": ("*"), + "polymorphic_identity": "node", + } + + class Content(Node): + __tablename__ = "content" + + id = Column(ForeignKey("nodes.id"), primary_key=True) + + __mapper_args__ = { + "polymorphic_identity": "content", + } + + class File(Node): + __tablename__ = "file" + + id = Column(ForeignKey("nodes.id"), primary_key=True) + __mapper_args__ = { + "polymorphic_identity": "file", + } + + class LocalGroup(ComparableEntity, Base): + __tablename__ = "local_group" + id = Column(Integer, primary_key=True) + + node_id = Column(ForeignKey("nodes.id")) + + @classmethod + def insert_data(cls, connection): + Node, LocalGroup = cls.classes("Node", "LocalGroup") + + with Session(connection) as sess: + f1 = Node(id=2, local_groups=[LocalGroup(), LocalGroup()]) + c1 = Node(id=1) + c1.parent = f1 + + sess.add_all([f1, c1]) + + sess.commit() + + def test_emit_lazy_loadonpk_parent(self): + Node, LocalGroup = self.classes("Node", "LocalGroup") + + s = fixture_session() + c1 = s.query(Node).filter_by(id=1).first() + + def go(): + p1 = c1.parent + eq_(p1, Node(id=2, local_groups=[LocalGroup(), LocalGroup()])) + + self.assert_sql_count(testing.db, go, 1) + + class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): @classmethod def define_tables(cls, metadata):