From: bekapono Date: Mon, 18 May 2026 16:29:50 +0000 (-0400) Subject: omit_join optimization for selectinload on many-to-many relationships X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=808fd28297f36bf932443bae77ca5bb16bcbd4dd;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git omit_join optimization for selectinload on many-to-many relationships The :func:`.selectinload` loader strategy now selects the ``omit_join`` optimization for many-to-many non-self-referential relationships, reducing the number of joins in the secondary SELECT by selecting from the secondary table directly rather than joining back to the parent entity. ``omit_join`` is enabled automatically when the join condition determines that the secondary table's foreign keys fully cover the parent's primary key. As always, ``omit_join`` can be disabled by setting :paramref:`.relationship.omit_join` to ``False``. Pull request courtesy bekapono. Fixes: #5987 Closes: #13278 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13278 Pull-request-sha: fdd9847e7db041be51160853c075b1427f7be051 Change-Id: Ib68f8e2be1399222383cdd7b55793fe88402212c --- diff --git a/doc/build/changelog/unreleased_21/5987.rst b/doc/build/changelog/unreleased_21/5987.rst new file mode 100644 index 0000000000..3ff37340f1 --- /dev/null +++ b/doc/build/changelog/unreleased_21/5987.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: usecase, orm, performance + :tickets: 5987 + + The :func:`.selectinload` loader strategy now selects the ``omit_join`` + optimization for many-to-many non-self-referential relationships, reducing + the number of joins in the secondary SELECT by selecting from the secondary + table directly rather than joining back to the parent entity. ``omit_join`` + is enabled automatically when the join condition determines that the + secondary table's foreign keys fully cover the parent's primary key. As + always, ``omit_join`` can be disabled by setting + :paramref:`.relationship.omit_join` to ``False``. Pull request courtesy + bekapono. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 71bb459294..8609fd3545 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1952,6 +1952,16 @@ class Mapper( _validate_polymorphic_identity = None + @HasMemoized.memoized_attribute + def _local_pk_cols(self) -> set[Any]: + pk_cols = util.column_set(self.primary_key) + equiv = self._equivalent_columns + for pk_col in pk_cols.intersection(equiv): + pk_cols.update(equiv[pk_col]) + return util.column_set( + col for col in pk_cols if col.table is self.local_table + ) + @HasMemoized.memoized_attribute def _version_id_prop(self): if self.version_id_col is not None: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 8ac91415b1..d81ae32ba5 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -2609,6 +2609,27 @@ class _JoinCondition: def _has_remote_annotations(self) -> bool: return self._has_annotation(self.primaryjoin, "remote") + @util.memoized_property + def secondary_covers_parent_primary_key(self) -> bool: + """Return True if the "secondary" selectable's join to the parent + table contains columns that encompass the complete primary key + value of the parent. + + Used in optimizing the selectinload loader strategy to indicate + the parent table need not be included in the query, as a complete + primary key can be derived from the secondary table. + + """ + if self.secondary is None: + return False + + secondary_synced_parent_cols = util.column_set( + l for (l, _) in self.synchronize_pairs + ) + return self.prop.parent._local_pk_cols.issubset( + secondary_synced_parent_cols + ) + def _annotate_fks(self) -> None: """Annotate the primaryjoin and secondaryjoin structures with 'foreign' annotations marking columns diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d7672b3e4a..c79eda64f8 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2996,6 +2996,7 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots): super().__init__(parent, strategy_key) self.join_depth = self.parent_property.join_depth is_m2o = self.parent_property.direction is interfaces.MANYTOONE + is_m2m = self.parent_property.direction is interfaces.MANYTOMANY if self.parent_property.omit_join is not None: self.omit_join = self.parent_property.omit_join @@ -3005,6 +3006,9 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots): ) if is_m2o: self.omit_join = lazyloader.use_get + elif is_m2m and not self.parent_property._is_self_referential: + join_cond = self.parent_property._join_condition + self.omit_join = join_cond.secondary_covers_parent_primary_key else: self.omit_join = self.parent._get_clause[0].compare( lazyloader._rev_lazywhere, @@ -3253,7 +3257,18 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots): }, ) - if not query_info.load_with_join: + if ( + self.parent_property.secondary is not None + and self.omit_join is True + ): + # The secondaryjoin condition is used to connect the + # secondary table to the related entity, + # and is required for composite foreign keys where SQLAlchemy + # cannot determine the join condition. + q = q.select_from(self.parent_property.secondary).join( + entity_sql, self.parent_property._join_condition.secondaryjoin + ) + elif not query_info.load_with_join: # the Bundle we have in the "omit_join" case is against raw, non # annotated columns, so to ensure the Query knows its primary # entity, we add it explicitly. If we made the Bundle against diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index 9bd97dd16f..7da30bc16d 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -1,4 +1,6 @@ import datetime +from typing import List +from typing import Optional import sqlalchemy as sa from sqlalchemy import and_ @@ -23,6 +25,8 @@ from sqlalchemy.orm import declarative_base from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import foreign from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import remote from sqlalchemy.orm import selectinload @@ -5636,6 +5640,471 @@ class InvalidRelationshipEscalationTestM2M( ) +class SecondaryCoversParentFlagTestM2M(fixtures.DeclarativeMappedTest): + def test_false_no_secondary(self): + Base = declarative_base() + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + bs: Mapped[list["B"]] = relationship("B", back_populates="a") + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + a: Mapped["A"] = relationship("A", back_populates="bs") + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is False + + def test_joins_only_on_unique_non_pk(self): + Base = declarative_base() + + atob = Table( + "atob", + Base.metadata, + Column("a_code", ForeignKey("a.code")), + Column("b_id", ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + code: Mapped[str] = mapped_column(unique=True) + bs: Mapped[list["B"]] = relationship("B", secondary=atob) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is False + + def test_joins_only_on_unique_non_pk_to_subclass(self): + Base = declarative_base() + + a_b = Table( + "a_b", + Base.metadata, + Column( + "a_code", String, ForeignKey("a_child.code"), primary_key=True + ), + Column("b_id", Integer, ForeignKey("b.id"), primary_key=True), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + class AChild(A): + __tablename__ = "a_child" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) + code: Mapped[str] = mapped_column(unique=True) + bs: Mapped[list["B"]] = relationship(secondary=a_b) + __mapper_args__ = {"polymorphic_identity": "a_child"} + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = AChild.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is False + + def test_joins_from_inherited_subclass(self): + Base = self.DeclarativeBasic + + a_b = Table( + "a_b", + Base.metadata, + Column( + "a_child_id", + String, + ForeignKey("a_child.id"), + primary_key=True, + ), + Column("b_id", Integer, ForeignKey("b.id"), primary_key=True), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + class AChild(A): + __tablename__ = "a_child" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) + code: Mapped[str] = mapped_column(unique=True) + bs: Mapped[list["B"]] = relationship(secondary=a_b) + __mapper_args__ = {"polymorphic_identity": "a_child"} + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = AChild.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is True + + def test_joins_from_parent_with_inheritance(self): + Base = declarative_base() + + a_b = Table( + "a_b", + Base.metadata, + Column("a_id", Integer, ForeignKey("a.id"), primary_key=True), + Column("b_id", Integer, ForeignKey("b.id"), primary_key=True), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] = mapped_column() + bs = relationship("B", secondary=a_b) + __mapper_args__ = { + "polymorphic_on": type, + "polymorphic_identity": "parent", + } + + class AChild(A): + __tablename__ = "a_child" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) + code: Mapped[str] = mapped_column(unique=True) + __mapper_args__ = {"polymorphic_identity": "child"} + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is True + + def test_joins_from_superclass_relationship_from_subclass(self): + """ + This is a special case where relationship bs is for + AChild, and AChild's pk is a fk from parent A. But + for table a_b the keys come from parent A and B. + + We are expecting the results to be False since + it should fail the `issubset` evaluation between AChild + and B. + + When forcing True to allow omit_join=True for this example, + the IN statement still works and gets optimized since + AChild.id = A.id + """ + Base = declarative_base() + + a_b = Table( + "a_b", + Base.metadata, + Column("a_id", Integer, ForeignKey("a.id"), primary_key=True), + Column("b_id", Integer, ForeignKey("b.id"), primary_key=True), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] = mapped_column() + __mapper_args__ = { + "polymorphic_on": type, + "polymorphic_identity": "parent", + } + + class AChild(A): + __tablename__ = "a_child" + id: Mapped[int] = mapped_column( + ForeignKey("a.id"), primary_key=True + ) + code: Mapped[str] = mapped_column(unique=True) + bs = relationship("B", secondary=a_b) + __mapper_args__ = {"polymorphic_identity": "child"} + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = AChild.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is False + + def test_joins_from_composite_pk(self): + Base = declarative_base() + + association_table = Table( + "a_b", + Base.metadata, + Column("a_id1", Integer, ForeignKey("a.id1")), + Column("a_id2", Integer, ForeignKey("a.id2")), + Column("b_id", Integer, ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id1: Mapped[int] = mapped_column(primary_key=True) + id2: Mapped[int] = mapped_column(primary_key=True) + bs = relationship( + "B", + secondary=association_table, + primaryjoin=lambda: and_( + A.id1 == association_table.c.a_id1, + A.id2 == association_table.c.a_id2, + ), + secondaryjoin=lambda: B.id == association_table.c.b_id, + ) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is True + + def test_only_partial_of_composite_pk(self): + Base = declarative_base() + + association_table = Table( + "a_b", + Base.metadata, + Column("a_id1", Integer, ForeignKey("a.id1")), + Column("b_id", Integer, ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id1: Mapped[int] = mapped_column(primary_key=True) + id2: Mapped[int] = mapped_column(primary_key=True) + bs = relationship( + "B", + secondary=association_table, + primaryjoin=lambda: A.id1 == association_table.c.a_id1, + secondaryjoin=lambda: B.id == association_table.c.b_id, + ) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is False + + def test_simple(self): + Base = declarative_base() + + atob = Table( + "atob", + Base.metadata, + Column("a_id", ForeignKey("a.id")), + Column("b_id", ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + bs: Mapped[list["B"]] = relationship("B", secondary=atob) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is True + + def test_bidirectional(self): + Base = declarative_base() + + atob = Table( + "atob", + Base.metadata, + Column("a_id", ForeignKey("a.id")), + Column("b_id", ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + bs: Mapped[list["B"]] = relationship( + "B", secondary=atob, back_populates="as_" + ) + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + as_: Mapped[list["A"]] = relationship( + "A", secondary=atob, back_populates="bs" + ) + + join_cond = A.bs.property._join_condition + a_flag = join_cond.secondary_covers_parent_primary_key + + join_cond = B.as_.property._join_condition + b_flag = join_cond.secondary_covers_parent_primary_key + + assert a_flag is True + assert b_flag is True + + def test_viewonly_association_table(self): + """ + Example pulled from Sqlalchemy documentation, with the only + changes was to include "viewonly=True" to the secondary + relationship. + """ + Base = declarative_base() + + class Association(Base): + __tablename__ = "association_table" + + left_id: Mapped[int] = mapped_column( + ForeignKey("left_table.id"), primary_key=True + ) + right_id: Mapped[int] = mapped_column( + ForeignKey("right_table.id"), primary_key=True + ) + extra_data: Mapped[Optional[str]] + + # association between Association -> Child + child: Mapped["Child"] = relationship( + back_populates="parent_associations" + ) + + # association between Association -> Parent + parent: Mapped["Parent"] = relationship( + back_populates="child_associations" + ) + + class Parent(Base): + __tablename__ = "left_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + # many-to-many relationship to Child, + # bypassing the `Association` class + children: Mapped[List["Child"]] = relationship( + secondary="association_table", + back_populates="parents", + viewonly=True, + ) + + # association between Parent -> Association -> Child + child_associations: Mapped[List["Association"]] = relationship( + back_populates="parent" + ) + + class Child(Base): + __tablename__ = "right_table" + + id: Mapped[int] = mapped_column(primary_key=True) + + # many-to-many relationship to Parent, + # bypassing the `Association` class + parents: Mapped[List["Parent"]] = relationship( + secondary="association_table", + back_populates="children", + viewonly=True, + ) + + # association between Child -> Association -> Parent + parent_associations: Mapped[List["Association"]] = relationship( + back_populates="child" + ) + + join_cond = Parent.children.property._join_condition + parent_child_flag = join_cond.secondary_covers_parent_primary_key + + join_cond = Child.parents.property._join_condition + child_parent_flag = join_cond.secondary_covers_parent_primary_key + + assert parent_child_flag is True + assert child_parent_flag is True + + def test_mapper_only_pk_true(self): + """secondary joins on the mapper-specified primary_key column, + not the table's actual primary key.""" + Base = declarative_base() + + atob = Table( + "atob", + Base.metadata, + Column("a_code", ForeignKey("a.code")), + Column("b_id", ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + code: Mapped[str] = mapped_column(unique=True) + bs: Mapped[list["B"]] = relationship("B", secondary=atob) + __mapper_args__ = {"primary_key": "code"} + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is True + + def test_mapper_only_pk_false(self): + """secondary does not join on the mapper-specified primary_key + column, so the flag should be False.""" + Base = declarative_base() + + atob = Table( + "atob", + Base.metadata, + Column("a_id", ForeignKey("a.id")), + Column("b_id", ForeignKey("b.id")), + ) + + class A(Base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + code: Mapped[str] = mapped_column(unique=True) + bs: Mapped[list["B"]] = relationship("B", secondary=atob) + __mapper_args__ = {"primary_key": "code"} + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + + join_cond = A.bs.property._join_condition + flag = join_cond.secondary_covers_parent_primary_key + + assert flag is False + + class ActiveHistoryFlagTest(_fixtures.FixtureTest): run_inserts = None run_deletes = None @@ -6674,11 +7143,11 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest): params=[{"id_1": 2}], ), CompiledSQL( - "SELECT a_1.id, b.id FROM a AS a_1 JOIN " + "SELECT anon_1.aid, b.id FROM " "(SELECT a.id AS aid, b.id AS id FROM a JOIN b ON a.b_ids " "LIKE (:id_1 || b.id || :param_1)) AS anon_1 " - "ON a_1.id = anon_1.aid JOIN b ON b.id = anon_1.id " - "WHERE a_1.id IN (__[POSTCOMPILE_primary_keys])", + "JOIN b ON b.id = anon_1.id " + "WHERE anon_1.aid IN (__[POSTCOMPILE_primary_keys])", params=[{"id_1": "%", "param_1": "%", "primary_keys": [2]}], ), ) diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index 4ca4778c0a..ba1d7e4f56 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -1,5 +1,7 @@ import sqlalchemy as sa +from sqlalchemy import and_ from sqlalchemy import bindparam +from sqlalchemy import Boolean from sqlalchemy import ForeignKey from sqlalchemy import ForeignKeyConstraint from sqlalchemy import Integer @@ -20,6 +22,7 @@ from sqlalchemy.orm import undefer from sqlalchemy.orm import with_polymorphic from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import assert_warns +from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -3638,6 +3641,593 @@ class M2OWDegradeTest( ) +class M2MOmitJoinTest( + fixtures.TestBase, AssertsExecutionResults, testing.AssertsCompiledSQL +): + __dialect__ = "default" + + @testing.fixture + def simple_m2m(self, decl_base, connection): + association_table = Table( + "a_b", + decl_base.metadata, + Column("a_id", Integer, ForeignKey("a.id")), + Column("b_id", Integer, ForeignKey("b.id")), + ) + + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + bs = relationship("B", secondary=association_table) + bs_no_omit_join = relationship( + "B", + secondary=association_table, + omit_join=False, + overlaps="bs", + ) + + class B(decl_base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + + decl_base.metadata.create_all(connection) + + with Session(connection) as session: + a1 = A(id=1) + a2 = A(id=2) + b1 = B(id=1) + b2 = B(id=2) + a1.bs = [b1, b2] + a2.bs = [b1] + session.add_all([a1, a2, b1, b2]) + session.commit() + + return A + + @testing.fixture + def symmetric_composite_m2m(self, decl_base, connection): + association_table = Table( + "a_b", + decl_base.metadata, + Column("a_id1", Integer, ForeignKey("a.id1")), + Column("a_id2", Integer, ForeignKey("a.id2")), + Column("b_id1", Integer, ForeignKey("b.id1")), + Column("b_id2", Integer, ForeignKey("b.id2")), + ) + + class B(decl_base): + __tablename__ = "b" + id1 = Column(Integer, primary_key=True) + id2 = Column(Integer, primary_key=True) + + class A(decl_base): + __tablename__ = "a" + id1 = Column(Integer, primary_key=True) + id2 = Column(Integer, primary_key=True) + bs = relationship( + "B", + secondary=association_table, + primaryjoin=lambda: and_( + A.id1 == association_table.c.a_id1, + A.id2 == association_table.c.a_id2, + ), + secondaryjoin=lambda: and_( + B.id1 == association_table.c.b_id1, + B.id2 == association_table.c.b_id2, + ), + ) + bs_no_omit_join = relationship( + "B", + secondary=association_table, + omit_join=False, + overlaps="bs", + primaryjoin=lambda: and_( + A.id1 == association_table.c.a_id1, + A.id2 == association_table.c.a_id2, + ), + secondaryjoin=lambda: and_( + B.id1 == association_table.c.b_id1, + B.id2 == association_table.c.b_id2, + ), + ) + + decl_base.metadata.create_all(connection) + + with Session(connection) as session: + a1 = A(id1=1, id2=1) + a2 = A(id1=1, id2=2) + b1 = B(id1=1, id2=1) + b2 = B(id1=1, id2=2) + a1.bs = [b1, b2] + a2.bs = [b1] + session.add_all([a1, a2, b1, b2]) + session.commit() + + return A + + @testing.fixture + def asymmetric_composite_m2m(self, decl_base, connection): + association_table = Table( + "a_b", + decl_base.metadata, + Column("a_id1", Integer, ForeignKey("a.id1")), + Column("a_id2", Integer, ForeignKey("a.id2")), + Column("b_id", Integer, ForeignKey("b.id")), + ) + + class A(decl_base): + __tablename__ = "a" + id1 = Column(Integer, primary_key=True) + id2 = Column(Integer, primary_key=True) + bs = relationship( + "B", + secondary=association_table, + primaryjoin=lambda: and_( + A.id1 == association_table.c.a_id1, + A.id2 == association_table.c.a_id2, + ), + secondaryjoin=(lambda: B.id == association_table.c.b_id), + ) + bs_no_omit_join = relationship( + "B", + secondary=association_table, + omit_join=False, + overlaps="bs", + primaryjoin=lambda: and_( + A.id1 == association_table.c.a_id1, + A.id2 == association_table.c.a_id2, + ), + secondaryjoin=(lambda: B.id == association_table.c.b_id), + ) + + class B(decl_base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + + decl_base.metadata.create_all(connection) + + with Session(connection) as session: + a1 = A(id1=1, id2=1) + a2 = A(id1=1, id2=2) + b1 = B(id=1) + b2 = B(id=2) + b3 = B(id=3) + a1.bs = [b1, b2, b3] + a2.bs = [b2, b3] + session.add_all([a1, a2, b1, b2, b3]) + session.commit() + + return A + + @testing.fixture + def reverse_asymmetric_composite_m2m(self, decl_base, connection): + association_table = Table( + "a_b", + decl_base.metadata, + Column("a_id", Integer, ForeignKey("a.id")), + Column("b_id1", Integer, ForeignKey("b.id1")), + Column("b_id2", Integer, ForeignKey("b.id2")), + ) + + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + bs = relationship( + "B", + secondary=association_table, + primaryjoin=(lambda: A.id == association_table.c.a_id), + secondaryjoin=lambda: and_( + B.id1 == association_table.c.b_id1, + B.id2 == association_table.c.b_id2, + ), + ) + bs_no_omit_join = relationship( + "B", + secondary=association_table, + omit_join=False, + overlaps="bs", + primaryjoin=(lambda: A.id == association_table.c.a_id), + secondaryjoin=lambda: and_( + B.id1 == association_table.c.b_id1, + B.id2 == association_table.c.b_id2, + ), + ) + + class B(decl_base): + __tablename__ = "b" + id1 = Column(Integer, primary_key=True) + id2 = Column(Integer, primary_key=True) + + decl_base.metadata.create_all(connection) + + with Session(connection) as session: + a1 = A(id=1) + a2 = A(id=2) + b1 = B(id1=1, id2=1) + b2 = B(id1=1, id2=2) + b3 = B(id1=2, id2=1) + a1.bs = [b1, b2, b3] + a2.bs = [b2, b3] + session.add_all([a1, a2, b1, b2, b3]) + session.commit() + + return A + + @testing.fixture + def filtered_secondaryjoin_m2m(self, decl_base, connection): + association_table = Table( + "a_b", + decl_base.metadata, + Column("a_id", Integer, ForeignKey("a.id")), + Column("b_id", Integer, ForeignKey("b.id")), + ) + + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + bs = relationship( + "B", + secondary=association_table, + primaryjoin=(lambda: A.id == association_table.c.a_id), + secondaryjoin=lambda: and_( + B.id == association_table.c.b_id, + B.active == True, # noqa: E712 + ), + ) + bs_no_omit_join = relationship( + "B", + secondary=association_table, + omit_join=False, + overlaps="bs", + primaryjoin=(lambda: A.id == association_table.c.a_id), + secondaryjoin=lambda: and_( + B.id == association_table.c.b_id, + B.active == True, # noqa: E712 + ), + ) + + class B(decl_base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + active = Column(Boolean, default=True) + + decl_base.metadata.create_all(connection) + + with Session(connection) as session: + a1 = A(id=1) + a2 = A(id=2) + b1 = B(id=1, active=True) + b2 = B(id=2, active=False) + b3 = B(id=3, active=True) + a1.bs = [b1, b2, b3] + a2.bs = [b2, b3] + session.add_all([a1, a2, b1, b2, b3]) + session.commit() + + return A + + def test_simple_optimized(self, simple_m2m, connection): + A = simple_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A).options(selectinload(A.bs)).order_by(A.id) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}), + CompiledSQL( + "SELECT a_b.a_id, b.id " + "FROM a_b JOIN b ON b.id = a_b.b_id " + "WHERE a_b.a_id IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [1, 2]}, + ), + ) + + def test_simple_unoptimized(self, simple_m2m, connection): + A = simple_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs_no_omit_join)) + .order_by(A.id) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}), + CompiledSQL( + "SELECT a_1.id, b.id " + "FROM a AS a_1 " + "JOIN a_b AS a_b_1 ON a_1.id = a_b_1.a_id " + "JOIN b ON b.id = a_b_1.b_id " + "WHERE a_1.id IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [1, 2]}, + ), + ) + + def test_symmetric_composite(self, symmetric_composite_m2m, connection): + A = symmetric_composite_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs)) + .order_by(A.id1, A.id2) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id1, a.id2 " "FROM a ORDER BY a.id1, a.id2", + {}, + ), + CompiledSQL( + "SELECT a_b.a_id1, a_b.a_id2, b.id1, b.id2 " + "FROM a_b " + "JOIN b ON b.id1 = a_b.b_id1 " + "AND b.id2 = a_b.b_id2 " + "WHERE (a_b.a_id1, a_b.a_id2) IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [(1, 1), (1, 2)]}, + ), + ) + + def test_symmetric_composite_unoptimized( + self, symmetric_composite_m2m, connection + ): + A = symmetric_composite_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs_no_omit_join)) + .order_by(A.id1, A.id2) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id1, a.id2 " "FROM a ORDER BY a.id1, a.id2", + {}, + ), + CompiledSQL( + "SELECT a_1.id1, a_1.id2, b.id1, b.id2 " + "FROM a AS a_1 JOIN a_b AS a_b_1 ON " + "a_1.id1 = a_b_1.a_id1 " + "AND a_1.id2 = a_b_1.a_id2 " + "JOIN b ON b.id1 = a_b_1.b_id1 " + "AND b.id2 = a_b_1.b_id2 " + "WHERE (a_1.id1, a_1.id2) IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [(1, 1), (1, 2)]}, + ), + ) + + def test_asymmetric_composite(self, asymmetric_composite_m2m, connection): + A = asymmetric_composite_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs)) + .order_by(A.id1, A.id2) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id1, a.id2 FROM a ORDER BY a.id1, a.id2", + {}, + ), + CompiledSQL( + "SELECT a_b.a_id1, a_b.a_id2, b.id " + "FROM a_b JOIN b ON b.id = a_b.b_id " + "WHERE (a_b.a_id1, a_b.a_id2) IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [(1, 1), (1, 2)]}, + ), + ) + + def test_asymmetric_composite_unoptimized( + self, asymmetric_composite_m2m, connection + ): + A = asymmetric_composite_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs_no_omit_join)) + .order_by(A.id1, A.id2) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id1, a.id2 FROM a ORDER BY a.id1, a.id2", + {}, + ), + CompiledSQL( + "SELECT a_1.id1, a_1.id2, b.id " + "FROM a AS a_1 JOIN a_b AS a_b_1 " + "ON a_1.id1 = a_b_1.a_id1 " + "AND a_1.id2 = a_b_1.a_id2 " + "JOIN b ON b.id = a_b_1.b_id " + "WHERE (a_1.id1, a_1.id2) IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [(1, 1), (1, 2)]}, + ), + ) + + def test_reverse_asymmetric_composite( + self, reverse_asymmetric_composite_m2m, connection + ): + A = reverse_asymmetric_composite_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A).options(selectinload(A.bs)).order_by(A.id) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id FROM a ORDER BY a.id", + {}, + ), + CompiledSQL( + "SELECT a_b.a_id, b.id1, b.id2 " + "FROM a_b " + "JOIN b ON b.id1 = a_b.b_id1 " + "AND b.id2 = a_b.b_id2 " + "WHERE a_b.a_id IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [1, 2]}, + ), + ) + + def test_reverse_asymmetric_composite_unoptimized( + self, reverse_asymmetric_composite_m2m, connection + ): + A = reverse_asymmetric_composite_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs_no_omit_join)) + .order_by(A.id) + ) + session.execute(statement).scalars().all() + + self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id FROM a ORDER BY a.id", + {}, + ), + CompiledSQL( + "SELECT a_1.id, b.id1, b.id2 " + "FROM a AS a_1 " + "JOIN a_b AS a_b_1 ON a_1.id = a_b_1.a_id " + "JOIN b ON b.id1 = a_b_1.b_id1 " + "AND b.id2 = a_b_1.b_id2 " + "WHERE a_1.id IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [1, 2]}, + ), + ) + + def test_filtered_secondaryjoin( + self, filtered_secondaryjoin_m2m, connection + ): + A = filtered_secondaryjoin_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A).options(selectinload(A.bs)).order_by(A.id) + ) + return session.execute(statement).scalars().all() + + results = self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id FROM a ORDER BY a.id", + {}, + ), + CompiledSQL( + "SELECT a_b.a_id, b.id, b.active " + "FROM a_b " + "JOIN b ON b.id = a_b.b_id AND b.active = 1 " + "WHERE a_b.a_id IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [1, 2]}, + ), + ) + + eq_(sorted(b.id for b in results[0].bs), [1, 3]) + eq_(sorted(b.id for b in results[1].bs), [3]) + + def test_filtered_secondaryjoin_unoptimized( + self, filtered_secondaryjoin_m2m, connection + ): + A = filtered_secondaryjoin_m2m + + with Session(connection) as session: + + def go(): + statement = ( + select(A) + .options(selectinload(A.bs_no_omit_join)) + .order_by(A.id) + ) + return session.execute(statement).scalars().all() + + results = self.assert_sql_execution( + connection, + go, + CompiledSQL( + "SELECT a.id FROM a ORDER BY a.id", + {}, + ), + CompiledSQL( + "SELECT a_1.id, b.id, b.active " + "FROM a AS a_1 " + "JOIN a_b AS a_b_1 ON a_1.id = a_b_1.a_id " + "JOIN b ON b.id = a_b_1.b_id AND b.active = 1 " + "WHERE a_1.id IN " + "(__[POSTCOMPILE_primary_keys])", + {"primary_keys": [1, 2]}, + ), + ) + + eq_( + sorted(b.id for b in results[0].bs_no_omit_join), + [1, 3], + ) + eq_( + sorted(b.id for b in results[1].bs_no_omit_join), + [3], + ) + + class SameNamePolymorphicTest(fixtures.DeclarativeMappedTest): @classmethod def setup_classes(cls):