]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
omit_join optimization for selectinload on many-to-many relationships
authorbekapono <bsiliezar2@gmail.com>
Mon, 18 May 2026 16:29:50 +0000 (12:29 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 27 May 2026 17:23:10 +0000 (13:23 -0400)
The :func:`.selectinload` loader strategy now selects the ``omit_join``
optimization for many-to-many non-self-referential relationships, reducing
the number of joins in the secondary SELECT by selecting from the secondary
table directly rather than joining back to the parent entity. ``omit_join``
is enabled automatically when the join condition determines that the
secondary table's foreign keys fully cover the parent's primary key. As
always, ``omit_join`` can be disabled by setting
:paramref:`.relationship.omit_join` to ``False``. Pull request courtesy
bekapono.

Fixes: #5987
Closes: #13278
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13278
Pull-request-sha: fdd9847e7db041be51160853c075b1427f7be051

Change-Id: Ib68f8e2be1399222383cdd7b55793fe88402212c

doc/build/changelog/unreleased_21/5987.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_relationships.py
test/orm/test_selectin_relations.py

diff --git a/doc/build/changelog/unreleased_21/5987.rst b/doc/build/changelog/unreleased_21/5987.rst
new file mode 100644 (file)
index 0000000..3ff3734
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+  :tags: usecase, orm, performance
+  :tickets: 5987
+
+  The :func:`.selectinload` loader strategy now selects the ``omit_join``
+  optimization for many-to-many non-self-referential relationships, reducing
+  the number of joins in the secondary SELECT by selecting from the secondary
+  table directly rather than joining back to the parent entity. ``omit_join``
+  is enabled automatically when the join condition determines that the
+  secondary table's foreign keys fully cover the parent's primary key. As
+  always, ``omit_join`` can be disabled by setting
+  :paramref:`.relationship.omit_join` to ``False``. Pull request courtesy
+  bekapono.
index 71bb4592945850ccf4981a1cf4920fde31c58622..8609fd35451cc5959e6ad3927ba47b940c6dec80 100644 (file)
@@ -1952,6 +1952,16 @@ class Mapper(
 
     _validate_polymorphic_identity = None
 
+    @HasMemoized.memoized_attribute
+    def _local_pk_cols(self) -> set[Any]:
+        pk_cols = util.column_set(self.primary_key)
+        equiv = self._equivalent_columns
+        for pk_col in pk_cols.intersection(equiv):
+            pk_cols.update(equiv[pk_col])
+        return util.column_set(
+            col for col in pk_cols if col.table is self.local_table
+        )
+
     @HasMemoized.memoized_attribute
     def _version_id_prop(self):
         if self.version_id_col is not None:
index 8ac91415b1c55255c7e31c47ba58b9a49ce291d6..d81ae32ba5504ecbd873924e76e4a13b018912a2 100644 (file)
@@ -2609,6 +2609,27 @@ class _JoinCondition:
     def _has_remote_annotations(self) -> bool:
         return self._has_annotation(self.primaryjoin, "remote")
 
+    @util.memoized_property
+    def secondary_covers_parent_primary_key(self) -> bool:
+        """Return True if the "secondary" selectable's join to the parent
+        table contains columns that encompass the complete primary key
+        value of the parent.
+
+        Used in optimizing the selectinload loader strategy to indicate
+        the parent table need not be included in the query, as a complete
+        primary key can be derived from the secondary table.
+
+        """
+        if self.secondary is None:
+            return False
+
+        secondary_synced_parent_cols = util.column_set(
+            l for (l, _) in self.synchronize_pairs
+        )
+        return self.prop.parent._local_pk_cols.issubset(
+            secondary_synced_parent_cols
+        )
+
     def _annotate_fks(self) -> None:
         """Annotate the primaryjoin and secondaryjoin
         structures with 'foreign' annotations marking columns
index d7672b3e4a052d2ad98a24880e9ad306be5d0663..c79eda64f8c0bb91da935466c6e44dbcbb41b9fc 100644 (file)
@@ -2996,6 +2996,7 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
         super().__init__(parent, strategy_key)
         self.join_depth = self.parent_property.join_depth
         is_m2o = self.parent_property.direction is interfaces.MANYTOONE
+        is_m2m = self.parent_property.direction is interfaces.MANYTOMANY
 
         if self.parent_property.omit_join is not None:
             self.omit_join = self.parent_property.omit_join
@@ -3005,6 +3006,9 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
             )
             if is_m2o:
                 self.omit_join = lazyloader.use_get
+            elif is_m2m and not self.parent_property._is_self_referential:
+                join_cond = self.parent_property._join_condition
+                self.omit_join = join_cond.secondary_covers_parent_primary_key
             else:
                 self.omit_join = self.parent._get_clause[0].compare(
                     lazyloader._rev_lazywhere,
@@ -3253,7 +3257,18 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
             },
         )
 
-        if not query_info.load_with_join:
+        if (
+            self.parent_property.secondary is not None
+            and self.omit_join is True
+        ):
+            # The secondaryjoin condition is used to connect the
+            # secondary table to the related entity,
+            # and is required for composite foreign keys where SQLAlchemy
+            # cannot determine the join condition.
+            q = q.select_from(self.parent_property.secondary).join(
+                entity_sql, self.parent_property._join_condition.secondaryjoin
+            )
+        elif not query_info.load_with_join:
             # the Bundle we have in the "omit_join" case is against raw, non
             # annotated columns, so to ensure the Query knows its primary
             # entity, we add it explicitly.  If we made the Bundle against
index 9bd97dd16f7cfa46599a7370d725838cee47c761..7da30bc16d3f0324e92c15ba79aae309c5f911c4 100644 (file)
@@ -1,4 +1,6 @@
 import datetime
+from typing import List
+from typing import Optional
 
 import sqlalchemy as sa
 from sqlalchemy import and_
@@ -23,6 +25,8 @@ from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import foreign
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import remote
 from sqlalchemy.orm import selectinload
@@ -5636,6 +5640,471 @@ class InvalidRelationshipEscalationTestM2M(
         )
 
 
+class SecondaryCoversParentFlagTestM2M(fixtures.DeclarativeMappedTest):
+    def test_false_no_secondary(self):
+        Base = declarative_base()
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            bs: Mapped[list["B"]] = relationship("B", back_populates="a")
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            a: Mapped["A"] = relationship("A", back_populates="bs")
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is False
+
+    def test_joins_only_on_unique_non_pk(self):
+        Base = declarative_base()
+
+        atob = Table(
+            "atob",
+            Base.metadata,
+            Column("a_code", ForeignKey("a.code")),
+            Column("b_id", ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            code: Mapped[str] = mapped_column(unique=True)
+            bs: Mapped[list["B"]] = relationship("B", secondary=atob)
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is False
+
+    def test_joins_only_on_unique_non_pk_to_subclass(self):
+        Base = declarative_base()
+
+        a_b = Table(
+            "a_b",
+            Base.metadata,
+            Column(
+                "a_code", String, ForeignKey("a_child.code"), primary_key=True
+            ),
+            Column("b_id", Integer, ForeignKey("b.id"), primary_key=True),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            type: Mapped[str]
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "a",
+            }
+
+        class AChild(A):
+            __tablename__ = "a_child"
+            id: Mapped[int] = mapped_column(
+                ForeignKey("a.id"), primary_key=True
+            )
+            code: Mapped[str] = mapped_column(unique=True)
+            bs: Mapped[list["B"]] = relationship(secondary=a_b)
+            __mapper_args__ = {"polymorphic_identity": "a_child"}
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = AChild.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is False
+
+    def test_joins_from_inherited_subclass(self):
+        Base = self.DeclarativeBasic
+
+        a_b = Table(
+            "a_b",
+            Base.metadata,
+            Column(
+                "a_child_id",
+                String,
+                ForeignKey("a_child.id"),
+                primary_key=True,
+            ),
+            Column("b_id", Integer, ForeignKey("b.id"), primary_key=True),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            type: Mapped[str]
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "a",
+            }
+
+        class AChild(A):
+            __tablename__ = "a_child"
+            id: Mapped[int] = mapped_column(
+                ForeignKey("a.id"), primary_key=True
+            )
+            code: Mapped[str] = mapped_column(unique=True)
+            bs: Mapped[list["B"]] = relationship(secondary=a_b)
+            __mapper_args__ = {"polymorphic_identity": "a_child"}
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = AChild.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is True
+
+    def test_joins_from_parent_with_inheritance(self):
+        Base = declarative_base()
+
+        a_b = Table(
+            "a_b",
+            Base.metadata,
+            Column("a_id", Integer, ForeignKey("a.id"), primary_key=True),
+            Column("b_id", Integer, ForeignKey("b.id"), primary_key=True),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            type: Mapped[str] = mapped_column()
+            bs = relationship("B", secondary=a_b)
+            __mapper_args__ = {
+                "polymorphic_on": type,
+                "polymorphic_identity": "parent",
+            }
+
+        class AChild(A):
+            __tablename__ = "a_child"
+            id: Mapped[int] = mapped_column(
+                ForeignKey("a.id"), primary_key=True
+            )
+            code: Mapped[str] = mapped_column(unique=True)
+            __mapper_args__ = {"polymorphic_identity": "child"}
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is True
+
+    def test_joins_from_superclass_relationship_from_subclass(self):
+        """
+        This is a special case where relationship bs is for
+        AChild, and AChild's pk is a fk from parent A. But
+        for table a_b the keys come from parent A and B.
+
+        We are expecting the results to be False since
+        it should fail the `issubset` evaluation between AChild
+        and B.
+
+        When forcing True to allow omit_join=True for this example,
+        the IN statement still works and gets optimized since
+        AChild.id = A.id
+        """
+        Base = declarative_base()
+
+        a_b = Table(
+            "a_b",
+            Base.metadata,
+            Column("a_id", Integer, ForeignKey("a.id"), primary_key=True),
+            Column("b_id", Integer, ForeignKey("b.id"), primary_key=True),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            type: Mapped[str] = mapped_column()
+            __mapper_args__ = {
+                "polymorphic_on": type,
+                "polymorphic_identity": "parent",
+            }
+
+        class AChild(A):
+            __tablename__ = "a_child"
+            id: Mapped[int] = mapped_column(
+                ForeignKey("a.id"), primary_key=True
+            )
+            code: Mapped[str] = mapped_column(unique=True)
+            bs = relationship("B", secondary=a_b)
+            __mapper_args__ = {"polymorphic_identity": "child"}
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = AChild.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is False
+
+    def test_joins_from_composite_pk(self):
+        Base = declarative_base()
+
+        association_table = Table(
+            "a_b",
+            Base.metadata,
+            Column("a_id1", Integer, ForeignKey("a.id1")),
+            Column("a_id2", Integer, ForeignKey("a.id2")),
+            Column("b_id", Integer, ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id1: Mapped[int] = mapped_column(primary_key=True)
+            id2: Mapped[int] = mapped_column(primary_key=True)
+            bs = relationship(
+                "B",
+                secondary=association_table,
+                primaryjoin=lambda: and_(
+                    A.id1 == association_table.c.a_id1,
+                    A.id2 == association_table.c.a_id2,
+                ),
+                secondaryjoin=lambda: B.id == association_table.c.b_id,
+            )
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is True
+
+    def test_only_partial_of_composite_pk(self):
+        Base = declarative_base()
+
+        association_table = Table(
+            "a_b",
+            Base.metadata,
+            Column("a_id1", Integer, ForeignKey("a.id1")),
+            Column("b_id", Integer, ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id1: Mapped[int] = mapped_column(primary_key=True)
+            id2: Mapped[int] = mapped_column(primary_key=True)
+            bs = relationship(
+                "B",
+                secondary=association_table,
+                primaryjoin=lambda: A.id1 == association_table.c.a_id1,
+                secondaryjoin=lambda: B.id == association_table.c.b_id,
+            )
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is False
+
+    def test_simple(self):
+        Base = declarative_base()
+
+        atob = Table(
+            "atob",
+            Base.metadata,
+            Column("a_id", ForeignKey("a.id")),
+            Column("b_id", ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            bs: Mapped[list["B"]] = relationship("B", secondary=atob)
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is True
+
+    def test_bidirectional(self):
+        Base = declarative_base()
+
+        atob = Table(
+            "atob",
+            Base.metadata,
+            Column("a_id", ForeignKey("a.id")),
+            Column("b_id", ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            bs: Mapped[list["B"]] = relationship(
+                "B", secondary=atob, back_populates="as_"
+            )
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            as_: Mapped[list["A"]] = relationship(
+                "A", secondary=atob, back_populates="bs"
+            )
+
+        join_cond = A.bs.property._join_condition
+        a_flag = join_cond.secondary_covers_parent_primary_key
+
+        join_cond = B.as_.property._join_condition
+        b_flag = join_cond.secondary_covers_parent_primary_key
+
+        assert a_flag is True
+        assert b_flag is True
+
+    def test_viewonly_association_table(self):
+        """
+        Example pulled from Sqlalchemy documentation, with the only
+        changes was to include "viewonly=True" to the secondary
+        relationship.
+        """
+        Base = declarative_base()
+
+        class Association(Base):
+            __tablename__ = "association_table"
+
+            left_id: Mapped[int] = mapped_column(
+                ForeignKey("left_table.id"), primary_key=True
+            )
+            right_id: Mapped[int] = mapped_column(
+                ForeignKey("right_table.id"), primary_key=True
+            )
+            extra_data: Mapped[Optional[str]]
+
+            # association between Association -> Child
+            child: Mapped["Child"] = relationship(
+                back_populates="parent_associations"
+            )
+
+            # association between Association -> Parent
+            parent: Mapped["Parent"] = relationship(
+                back_populates="child_associations"
+            )
+
+        class Parent(Base):
+            __tablename__ = "left_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            # many-to-many relationship to Child,
+            # bypassing the `Association` class
+            children: Mapped[List["Child"]] = relationship(
+                secondary="association_table",
+                back_populates="parents",
+                viewonly=True,
+            )
+
+            # association between Parent -> Association -> Child
+            child_associations: Mapped[List["Association"]] = relationship(
+                back_populates="parent"
+            )
+
+        class Child(Base):
+            __tablename__ = "right_table"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            # many-to-many relationship to Parent,
+            # bypassing the `Association` class
+            parents: Mapped[List["Parent"]] = relationship(
+                secondary="association_table",
+                back_populates="children",
+                viewonly=True,
+            )
+
+            # association between Child -> Association -> Parent
+            parent_associations: Mapped[List["Association"]] = relationship(
+                back_populates="child"
+            )
+
+        join_cond = Parent.children.property._join_condition
+        parent_child_flag = join_cond.secondary_covers_parent_primary_key
+
+        join_cond = Child.parents.property._join_condition
+        child_parent_flag = join_cond.secondary_covers_parent_primary_key
+
+        assert parent_child_flag is True
+        assert child_parent_flag is True
+
+    def test_mapper_only_pk_true(self):
+        """secondary joins on the mapper-specified primary_key column,
+        not the table's actual primary key."""
+        Base = declarative_base()
+
+        atob = Table(
+            "atob",
+            Base.metadata,
+            Column("a_code", ForeignKey("a.code")),
+            Column("b_id", ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            code: Mapped[str] = mapped_column(unique=True)
+            bs: Mapped[list["B"]] = relationship("B", secondary=atob)
+            __mapper_args__ = {"primary_key": "code"}
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is True
+
+    def test_mapper_only_pk_false(self):
+        """secondary does not join on the mapper-specified primary_key
+        column, so the flag should be False."""
+        Base = declarative_base()
+
+        atob = Table(
+            "atob",
+            Base.metadata,
+            Column("a_id", ForeignKey("a.id")),
+            Column("b_id", ForeignKey("b.id")),
+        )
+
+        class A(Base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            code: Mapped[str] = mapped_column(unique=True)
+            bs: Mapped[list["B"]] = relationship("B", secondary=atob)
+            __mapper_args__ = {"primary_key": "code"}
+
+        class B(Base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        join_cond = A.bs.property._join_condition
+        flag = join_cond.secondary_covers_parent_primary_key
+
+        assert flag is False
+
+
 class ActiveHistoryFlagTest(_fixtures.FixtureTest):
     run_inserts = None
     run_deletes = None
@@ -6674,11 +7143,11 @@ class SecondaryIncludesLocalColsTest(fixtures.MappedTest):
                 params=[{"id_1": 2}],
             ),
             CompiledSQL(
-                "SELECT a_1.id, b.id FROM a AS a_1 JOIN "
+                "SELECT anon_1.aid, b.id FROM "
                 "(SELECT a.id AS aid, b.id AS id FROM a JOIN b ON a.b_ids "
                 "LIKE (:id_1 || b.id || :param_1)) AS anon_1 "
-                "ON a_1.id = anon_1.aid JOIN b ON b.id = anon_1.id "
-                "WHERE a_1.id IN (__[POSTCOMPILE_primary_keys])",
+                "JOIN b ON b.id = anon_1.id "
+                "WHERE anon_1.aid IN (__[POSTCOMPILE_primary_keys])",
                 params=[{"id_1": "%", "param_1": "%", "primary_keys": [2]}],
             ),
         )
index 4ca4778c0a54329997bb24014fc9ad4a53473537..ba1d7e4f5676964ce4b36a1c8e70149ad70b12b2 100644 (file)
@@ -1,5 +1,7 @@
 import sqlalchemy as sa
+from sqlalchemy import and_
 from sqlalchemy import bindparam
+from sqlalchemy import Boolean
 from sqlalchemy import ForeignKey
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import Integer
@@ -20,6 +22,7 @@ from sqlalchemy.orm import undefer
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import assert_warns
+from sqlalchemy.testing import AssertsExecutionResults
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -3638,6 +3641,593 @@ class M2OWDegradeTest(
         )
 
 
+class M2MOmitJoinTest(
+    fixtures.TestBase, AssertsExecutionResults, testing.AssertsCompiledSQL
+):
+    __dialect__ = "default"
+
+    @testing.fixture
+    def simple_m2m(self, decl_base, connection):
+        association_table = Table(
+            "a_b",
+            decl_base.metadata,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            Column("b_id", Integer, ForeignKey("b.id")),
+        )
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            bs = relationship("B", secondary=association_table)
+            bs_no_omit_join = relationship(
+                "B",
+                secondary=association_table,
+                omit_join=False,
+                overlaps="bs",
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as session:
+            a1 = A(id=1)
+            a2 = A(id=2)
+            b1 = B(id=1)
+            b2 = B(id=2)
+            a1.bs = [b1, b2]
+            a2.bs = [b1]
+            session.add_all([a1, a2, b1, b2])
+            session.commit()
+
+        return A
+
+    @testing.fixture
+    def symmetric_composite_m2m(self, decl_base, connection):
+        association_table = Table(
+            "a_b",
+            decl_base.metadata,
+            Column("a_id1", Integer, ForeignKey("a.id1")),
+            Column("a_id2", Integer, ForeignKey("a.id2")),
+            Column("b_id1", Integer, ForeignKey("b.id1")),
+            Column("b_id2", Integer, ForeignKey("b.id2")),
+        )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id1 = Column(Integer, primary_key=True)
+            id2 = Column(Integer, primary_key=True)
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id1 = Column(Integer, primary_key=True)
+            id2 = Column(Integer, primary_key=True)
+            bs = relationship(
+                "B",
+                secondary=association_table,
+                primaryjoin=lambda: and_(
+                    A.id1 == association_table.c.a_id1,
+                    A.id2 == association_table.c.a_id2,
+                ),
+                secondaryjoin=lambda: and_(
+                    B.id1 == association_table.c.b_id1,
+                    B.id2 == association_table.c.b_id2,
+                ),
+            )
+            bs_no_omit_join = relationship(
+                "B",
+                secondary=association_table,
+                omit_join=False,
+                overlaps="bs",
+                primaryjoin=lambda: and_(
+                    A.id1 == association_table.c.a_id1,
+                    A.id2 == association_table.c.a_id2,
+                ),
+                secondaryjoin=lambda: and_(
+                    B.id1 == association_table.c.b_id1,
+                    B.id2 == association_table.c.b_id2,
+                ),
+            )
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as session:
+            a1 = A(id1=1, id2=1)
+            a2 = A(id1=1, id2=2)
+            b1 = B(id1=1, id2=1)
+            b2 = B(id1=1, id2=2)
+            a1.bs = [b1, b2]
+            a2.bs = [b1]
+            session.add_all([a1, a2, b1, b2])
+            session.commit()
+
+        return A
+
+    @testing.fixture
+    def asymmetric_composite_m2m(self, decl_base, connection):
+        association_table = Table(
+            "a_b",
+            decl_base.metadata,
+            Column("a_id1", Integer, ForeignKey("a.id1")),
+            Column("a_id2", Integer, ForeignKey("a.id2")),
+            Column("b_id", Integer, ForeignKey("b.id")),
+        )
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id1 = Column(Integer, primary_key=True)
+            id2 = Column(Integer, primary_key=True)
+            bs = relationship(
+                "B",
+                secondary=association_table,
+                primaryjoin=lambda: and_(
+                    A.id1 == association_table.c.a_id1,
+                    A.id2 == association_table.c.a_id2,
+                ),
+                secondaryjoin=(lambda: B.id == association_table.c.b_id),
+            )
+            bs_no_omit_join = relationship(
+                "B",
+                secondary=association_table,
+                omit_join=False,
+                overlaps="bs",
+                primaryjoin=lambda: and_(
+                    A.id1 == association_table.c.a_id1,
+                    A.id2 == association_table.c.a_id2,
+                ),
+                secondaryjoin=(lambda: B.id == association_table.c.b_id),
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as session:
+            a1 = A(id1=1, id2=1)
+            a2 = A(id1=1, id2=2)
+            b1 = B(id=1)
+            b2 = B(id=2)
+            b3 = B(id=3)
+            a1.bs = [b1, b2, b3]
+            a2.bs = [b2, b3]
+            session.add_all([a1, a2, b1, b2, b3])
+            session.commit()
+
+        return A
+
+    @testing.fixture
+    def reverse_asymmetric_composite_m2m(self, decl_base, connection):
+        association_table = Table(
+            "a_b",
+            decl_base.metadata,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            Column("b_id1", Integer, ForeignKey("b.id1")),
+            Column("b_id2", Integer, ForeignKey("b.id2")),
+        )
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            bs = relationship(
+                "B",
+                secondary=association_table,
+                primaryjoin=(lambda: A.id == association_table.c.a_id),
+                secondaryjoin=lambda: and_(
+                    B.id1 == association_table.c.b_id1,
+                    B.id2 == association_table.c.b_id2,
+                ),
+            )
+            bs_no_omit_join = relationship(
+                "B",
+                secondary=association_table,
+                omit_join=False,
+                overlaps="bs",
+                primaryjoin=(lambda: A.id == association_table.c.a_id),
+                secondaryjoin=lambda: and_(
+                    B.id1 == association_table.c.b_id1,
+                    B.id2 == association_table.c.b_id2,
+                ),
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id1 = Column(Integer, primary_key=True)
+            id2 = Column(Integer, primary_key=True)
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as session:
+            a1 = A(id=1)
+            a2 = A(id=2)
+            b1 = B(id1=1, id2=1)
+            b2 = B(id1=1, id2=2)
+            b3 = B(id1=2, id2=1)
+            a1.bs = [b1, b2, b3]
+            a2.bs = [b2, b3]
+            session.add_all([a1, a2, b1, b2, b3])
+            session.commit()
+
+        return A
+
+    @testing.fixture
+    def filtered_secondaryjoin_m2m(self, decl_base, connection):
+        association_table = Table(
+            "a_b",
+            decl_base.metadata,
+            Column("a_id", Integer, ForeignKey("a.id")),
+            Column("b_id", Integer, ForeignKey("b.id")),
+        )
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            bs = relationship(
+                "B",
+                secondary=association_table,
+                primaryjoin=(lambda: A.id == association_table.c.a_id),
+                secondaryjoin=lambda: and_(
+                    B.id == association_table.c.b_id,
+                    B.active == True,  # noqa: E712
+                ),
+            )
+            bs_no_omit_join = relationship(
+                "B",
+                secondary=association_table,
+                omit_join=False,
+                overlaps="bs",
+                primaryjoin=(lambda: A.id == association_table.c.a_id),
+                secondaryjoin=lambda: and_(
+                    B.id == association_table.c.b_id,
+                    B.active == True,  # noqa: E712
+                ),
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            active = Column(Boolean, default=True)
+
+        decl_base.metadata.create_all(connection)
+
+        with Session(connection) as session:
+            a1 = A(id=1)
+            a2 = A(id=2)
+            b1 = B(id=1, active=True)
+            b2 = B(id=2, active=False)
+            b3 = B(id=3, active=True)
+            a1.bs = [b1, b2, b3]
+            a2.bs = [b2, b3]
+            session.add_all([a1, a2, b1, b2, b3])
+            session.commit()
+
+        return A
+
+    def test_simple_optimized(self, simple_m2m, connection):
+        A = simple_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A).options(selectinload(A.bs)).order_by(A.id)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}),
+                CompiledSQL(
+                    "SELECT a_b.a_id, b.id "
+                    "FROM a_b JOIN b ON b.id = a_b.b_id "
+                    "WHERE a_b.a_id IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [1, 2]},
+                ),
+            )
+
+    def test_simple_unoptimized(self, simple_m2m, connection):
+        A = simple_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs_no_omit_join))
+                    .order_by(A.id)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}),
+                CompiledSQL(
+                    "SELECT a_1.id, b.id "
+                    "FROM a AS a_1 "
+                    "JOIN a_b AS a_b_1 ON a_1.id = a_b_1.a_id "
+                    "JOIN b ON b.id = a_b_1.b_id "
+                    "WHERE a_1.id IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [1, 2]},
+                ),
+            )
+
+    def test_symmetric_composite(self, symmetric_composite_m2m, connection):
+        A = symmetric_composite_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs))
+                    .order_by(A.id1, A.id2)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id1, a.id2 " "FROM a ORDER BY a.id1, a.id2",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_b.a_id1, a_b.a_id2, b.id1, b.id2 "
+                    "FROM a_b "
+                    "JOIN b ON b.id1 = a_b.b_id1 "
+                    "AND b.id2 = a_b.b_id2 "
+                    "WHERE (a_b.a_id1, a_b.a_id2) IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [(1, 1), (1, 2)]},
+                ),
+            )
+
+    def test_symmetric_composite_unoptimized(
+        self, symmetric_composite_m2m, connection
+    ):
+        A = symmetric_composite_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs_no_omit_join))
+                    .order_by(A.id1, A.id2)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id1, a.id2 " "FROM a ORDER BY a.id1, a.id2",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_1.id1, a_1.id2, b.id1, b.id2 "
+                    "FROM a AS a_1 JOIN a_b AS a_b_1 ON "
+                    "a_1.id1 = a_b_1.a_id1 "
+                    "AND a_1.id2 = a_b_1.a_id2 "
+                    "JOIN b ON b.id1 = a_b_1.b_id1 "
+                    "AND b.id2 = a_b_1.b_id2 "
+                    "WHERE (a_1.id1, a_1.id2) IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [(1, 1), (1, 2)]},
+                ),
+            )
+
+    def test_asymmetric_composite(self, asymmetric_composite_m2m, connection):
+        A = asymmetric_composite_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs))
+                    .order_by(A.id1, A.id2)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id1, a.id2 FROM a ORDER BY a.id1, a.id2",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_b.a_id1, a_b.a_id2, b.id "
+                    "FROM a_b JOIN b ON b.id = a_b.b_id "
+                    "WHERE (a_b.a_id1, a_b.a_id2) IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [(1, 1), (1, 2)]},
+                ),
+            )
+
+    def test_asymmetric_composite_unoptimized(
+        self, asymmetric_composite_m2m, connection
+    ):
+        A = asymmetric_composite_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs_no_omit_join))
+                    .order_by(A.id1, A.id2)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id1, a.id2 FROM a ORDER BY a.id1, a.id2",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_1.id1, a_1.id2, b.id "
+                    "FROM a AS a_1 JOIN a_b AS a_b_1 "
+                    "ON a_1.id1 = a_b_1.a_id1 "
+                    "AND a_1.id2 = a_b_1.a_id2 "
+                    "JOIN b ON b.id = a_b_1.b_id "
+                    "WHERE (a_1.id1, a_1.id2) IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [(1, 1), (1, 2)]},
+                ),
+            )
+
+    def test_reverse_asymmetric_composite(
+        self, reverse_asymmetric_composite_m2m, connection
+    ):
+        A = reverse_asymmetric_composite_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A).options(selectinload(A.bs)).order_by(A.id)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id FROM a ORDER BY a.id",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_b.a_id, b.id1, b.id2 "
+                    "FROM a_b "
+                    "JOIN b ON b.id1 = a_b.b_id1 "
+                    "AND b.id2 = a_b.b_id2 "
+                    "WHERE a_b.a_id IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [1, 2]},
+                ),
+            )
+
+    def test_reverse_asymmetric_composite_unoptimized(
+        self, reverse_asymmetric_composite_m2m, connection
+    ):
+        A = reverse_asymmetric_composite_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs_no_omit_join))
+                    .order_by(A.id)
+                )
+                session.execute(statement).scalars().all()
+
+            self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id FROM a ORDER BY a.id",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_1.id, b.id1, b.id2 "
+                    "FROM a AS a_1 "
+                    "JOIN a_b AS a_b_1 ON a_1.id = a_b_1.a_id "
+                    "JOIN b ON b.id1 = a_b_1.b_id1 "
+                    "AND b.id2 = a_b_1.b_id2 "
+                    "WHERE a_1.id IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [1, 2]},
+                ),
+            )
+
+    def test_filtered_secondaryjoin(
+        self, filtered_secondaryjoin_m2m, connection
+    ):
+        A = filtered_secondaryjoin_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A).options(selectinload(A.bs)).order_by(A.id)
+                )
+                return session.execute(statement).scalars().all()
+
+            results = self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id FROM a ORDER BY a.id",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_b.a_id, b.id, b.active "
+                    "FROM a_b "
+                    "JOIN b ON b.id = a_b.b_id AND b.active = 1 "
+                    "WHERE a_b.a_id IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [1, 2]},
+                ),
+            )
+
+            eq_(sorted(b.id for b in results[0].bs), [1, 3])
+            eq_(sorted(b.id for b in results[1].bs), [3])
+
+    def test_filtered_secondaryjoin_unoptimized(
+        self, filtered_secondaryjoin_m2m, connection
+    ):
+        A = filtered_secondaryjoin_m2m
+
+        with Session(connection) as session:
+
+            def go():
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs_no_omit_join))
+                    .order_by(A.id)
+                )
+                return session.execute(statement).scalars().all()
+
+            results = self.assert_sql_execution(
+                connection,
+                go,
+                CompiledSQL(
+                    "SELECT a.id FROM a ORDER BY a.id",
+                    {},
+                ),
+                CompiledSQL(
+                    "SELECT a_1.id, b.id, b.active "
+                    "FROM a AS a_1 "
+                    "JOIN a_b AS a_b_1 ON a_1.id = a_b_1.a_id "
+                    "JOIN b ON b.id = a_b_1.b_id AND b.active = 1 "
+                    "WHERE a_1.id IN "
+                    "(__[POSTCOMPILE_primary_keys])",
+                    {"primary_keys": [1, 2]},
+                ),
+            )
+
+            eq_(
+                sorted(b.id for b in results[0].bs_no_omit_join),
+                [1, 3],
+            )
+            eq_(
+                sorted(b.id for b in results[1].bs_no_omit_join),
+                [3],
+            )
+
+
 class SameNamePolymorphicTest(fixtures.DeclarativeMappedTest):
     @classmethod
     def setup_classes(cls):