]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
configurable chunksize parameter for selectinload
authorbekapono <bsiliezar2@gmail.com>
Wed, 22 Apr 2026 21:15:29 +0000 (17:15 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 23 Apr 2026 18:26:22 +0000 (14:26 -0400)
Added :paramref:`.selectinload.chunksize` parameter to :func`.selectinload`
allowing users to configure the number of primary keys sent per IN clause
when loading reltaionships. Pull request courtesy bekapono.

Fixes: #11450
Closes: #13235
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13235
Pull-request-sha: 360585a48c0fe898aa249769e9c7c1171f9e0988

Change-Id: Id09776b7ba53c630a780f128fc67dfdc085a4062

doc/build/changelog/unreleased_21/11450.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/strategy_options.py
test/orm/test_selectin_relations.py

diff --git a/doc/build/changelog/unreleased_21/11450.rst b/doc/build/changelog/unreleased_21/11450.rst
new file mode 100644 (file)
index 0000000..371f86d
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+  :tags: feature, orm
+  :tickets: 11450
+
+  Added :paramref:`.selectinload.chunksize` parameter to :func`.selectinload`
+  allowing users to configure the number of primary keys sent per IN clause
+  when loading reltaionships. Pull request courtesy bekapono.
index 69bb716663c87e78e3e193212d8f37b44265325f..d7672b3e4a052d2ad98a24880e9ad306be5d0663 100644 (file)
@@ -2977,6 +2977,21 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
 
     _chunksize = 500
 
+    @classmethod
+    def _set_chunksize(cls, loadopt) -> int:
+        if loadopt is None or hasattr(loadopt, "local_opts") is None:
+            return cls._chunksize
+
+        user_input = loadopt.local_opts.get("chunksize", None)
+        if user_input is None:
+            return cls._chunksize
+        elif not isinstance(user_input, int) or user_input < 1:
+            raise sa_exc.ArgumentError(
+                f"'chunksize={user_input}' is not an appropriate input, "
+                f"please use a positive non-zero integer."
+            )
+        return user_input
+
     def __init__(self, parent, strategy_key):
         super().__init__(parent, strategy_key)
         self.join_depth = self.parent_property.join_depth
@@ -3347,6 +3362,8 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
                     _setup_outermost_orderby, self.parent_property
                 )
 
+        chunksize = self._set_chunksize(loadopt)
+
         if query_info.load_only_child:
             self._load_via_child(
                 our_states,
@@ -3355,10 +3372,16 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
                 q,
                 context,
                 execution_options,
+                chunksize,
             )
         else:
             self._load_via_parent(
-                our_states, query_info, q, context, execution_options
+                our_states,
+                query_info,
+                q,
+                context,
+                execution_options,
+                chunksize,
             )
 
     def _load_via_child(
@@ -3369,14 +3392,15 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
         q,
         context,
         execution_options,
+        chunksize,
     ):
         uselist = self.uselist
 
         # this sort is really for the benefit of the unit tests
         our_keys = sorted(our_states)
         while our_keys:
-            chunk = our_keys[0 : self._chunksize]
-            our_keys = our_keys[self._chunksize :]
+            chunk = our_keys[0:chunksize]
+            our_keys = our_keys[chunksize:]
             data = {
                 k: v
                 for k, v in context.session.execute(
@@ -3417,14 +3441,14 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
             state.get_impl(self.key).set_committed_value(state, dict_, None)
 
     def _load_via_parent(
-        self, our_states, query_info, q, context, execution_options
+        self, our_states, query_info, q, context, execution_options, chunksize
     ):
         uselist = self.uselist
         _empty_result = () if uselist else None
 
         while our_states:
-            chunk = our_states[0 : self._chunksize]
-            our_states = our_states[self._chunksize :]
+            chunk = our_states[0:chunksize]
+            our_states = our_states[chunksize:]
 
             primary_keys = [
                 key[0] if query_info.zero_idx else key
index d0ed2803d330b1887382accb9a9b09b905b4b596..c9bd03b67a3f3f14b3866e31e5354b7b47953fad 100644 (file)
@@ -359,6 +359,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         self,
         attr: _AttrType,
         recursion_depth: Optional[int] = None,
+        chunksize: Optional[int] = None,
     ) -> Self:
         """Indicate that the given attribute should be loaded using
         SELECT IN eager loading.
@@ -397,6 +398,11 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
          .. versionadded:: 2.0 added
             :paramref:`_orm.selectinload.recursion_depth`
 
+        :param chunksize: optional int; when set to a positive non-zero
+         integer, the keys from the IN statement will be chunked relative
+         to the passed parameter
+
+         .. versionadded:: 2.1.0b3
 
         .. seealso::
 
@@ -408,7 +414,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         return self._set_relationship_strategy(
             attr,
             {"lazy": "selectin"},
-            opts={"recursion_depth": recursion_depth},
+            opts={"recursion_depth": recursion_depth, "chunksize": chunksize},
         )
 
     def lazyload(self, attr: _AttrType) -> Self:
@@ -2453,10 +2459,15 @@ def subqueryload(*keys: _AttrType) -> _AbstractLoad:
 
 @loader_unbound_fn
 def selectinload(
-    *keys: _AttrType, recursion_depth: Optional[int] = None
+    *keys: _AttrType,
+    recursion_depth: Optional[int] = None,
+    chunksize: Optional[int] = None,
 ) -> _AbstractLoad:
     return _generate_from_keys(
-        Load.selectinload, keys, False, {"recursion_depth": recursion_depth}
+        Load.selectinload,
+        keys,
+        False,
+        {"recursion_depth": recursion_depth, "chunksize": chunksize},
     )
 
 
index 9623cf0fae4229b05d1bfec821a2bd354f921de7..4ca4778c0a54329997bb24014fc9ad4a53473537 100644 (file)
@@ -2379,44 +2379,81 @@ class ChunkingTest(fixtures.DeclarativeMappedTest):
         )
         session.commit()
 
-    def test_odd_number_chunks(self):
+    @testing.combinations(
+        (None, (1, 101)),
+        (47, (1, 48, 95, 101)),
+        (50, (1, 51, 101)),
+        (99, (1, 100, 101)),
+        (108, (1, 101)),
+        argnames="chunksize, expected_range",
+    )
+    @testing.variation("chunksize_spec", ["monkeypatch", "parameter"])
+    def test_odd_number_chunks(
+        self, chunksize, expected_range, chunksize_spec
+    ):
         A, B = self.classes("A", "B")
 
         session = fixture_session()
 
         def go():
-            with mock.patch(
-                "sqlalchemy.orm.strategies._SelectInLoader._chunksize", 47
-            ):
-                q = session.query(A).options(selectinload(A.bs)).order_by(A.id)
+            if chunksize_spec.monkeypatch:
+                if chunksize is None:
+                    statement = (
+                        select(A).options(selectinload(A.bs)).order_by(A.id)
+                    )
+
+                    session.scalars(statement).all()
+                else:
+                    with mock.patch(
+                        "sqlalchemy.orm.strategies._SelectInLoader._chunksize",
+                        chunksize,
+                    ):
+
+                        statement = (
+                            select(A)
+                            .options(selectinload(A.bs))
+                            .order_by(A.id)
+                        )
 
-                for a in q:
-                    a.bs
+                        session.scalars(statement).all()
+            else:
+
+                statement = (
+                    select(A)
+                    .options(selectinload(A.bs, chunksize=chunksize))
+                    .order_by(A.id)
+                )
+
+                session.scalars(statement).all()
 
         self.assert_sql_execution(
             testing.db,
             go,
-            CompiledSQL("SELECT a.id AS a_id FROM a ORDER BY a.id", {}),
-            CompiledSQL(
-                "SELECT b.a_id, b.id "
-                "FROM b WHERE b.a_id IN "
-                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
-                {"primary_keys": list(range(1, 48))},
-            ),
-            CompiledSQL(
-                "SELECT b.a_id, b.id "
-                "FROM b WHERE b.a_id IN "
-                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
-                {"primary_keys": list(range(48, 95))},
-            ),
-            CompiledSQL(
-                "SELECT b.a_id, b.id "
-                "FROM b WHERE b.a_id IN "
-                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
-                {"primary_keys": list(range(95, 101))},
-            ),
+            CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}),
+            *[
+                CompiledSQL(
+                    "SELECT b.a_id, b.id "
+                    "FROM b WHERE b.a_id IN "
+                    "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
+                    {"primary_keys": list(range(a, b))},
+                )
+                for a, b in zip(expected_range, expected_range[1:])
+            ],
         )
 
+    @testing.combinations(-250, "a", 0)
+    def test_chunksize_value_error(self, chunksize):
+        A, B = self.classes("A", "B")
+
+        def go():
+            with testing.expect_raises_message(
+                sa.exc.ArgumentError,
+                ".*please use a positive non-zero integer.*",
+            ):
+                select(A).options(
+                    selectinload(A.bs, chunksize=chunksize)
+                ).order_by(A.id)
+
     @testing.requires.independent_cursors
     def test_yield_per(self):
         # the docs make a lot of guarantees about yield_per
@@ -2493,6 +2530,153 @@ class ChunkingTest(fixtures.DeclarativeMappedTest):
         )
 
 
+class ChainedChunkingTest(fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_mappers(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(ComparableEntity, Base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            bs = relationship("B", order_by="B.id", back_populates="a")
+
+        class B(ComparableEntity, Base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey("a.id"))
+            a = relationship("A", back_populates="bs")
+            cs = relationship("C", order_by="C.id", back_populates="b")
+
+        class C(ComparableEntity, Base):
+            __tablename__ = "c"
+            id = Column(Integer, primary_key=True)
+            b_id = Column(ForeignKey("b.id"))
+            b = relationship("B", back_populates="cs")
+
+    @classmethod
+    def insert_data(cls, connection):
+        A, B, C = cls.classes("A", "B", "C")
+
+        session = Session(connection)
+
+        for i in range(1, 6):
+            b_list = []
+            for j in range(1, 4):
+                b_id = (i * 6) + j
+                c_id = b_id + 1
+
+                b_list.append(B(id=b_id, cs=[C(id=c_id)]))
+            session.add(A(id=i, bs=b_list))
+        session.commit()
+
+    def test_chained_selectinload_with_two_custom_chunksize(self):
+        A, B, C = self.classes("A", "B", "C")
+
+        b_list = [7, 8, 9, 13, 14, 15, 19, 20, 21, 25, 26, 27, 31, 32, 33]
+
+        session = fixture_session()
+
+        def go():
+            statement = (
+                select(A)
+                .options(
+                    selectinload(A.bs, chunksize=3).selectinload(
+                        B.cs, chunksize=4
+                    )
+                )
+                .order_by(A.id)
+            )
+
+            session.scalars(statement).all()
+
+        self.assert_sql_execution(
+            testing.db,
+            go,
+            CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}),
+            CompiledSQL(
+                "SELECT b.a_id, b.id "
+                "FROM b WHERE b.a_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
+                {"primary_keys": list(range(1, 4))},
+            ),
+            CompiledSQL(
+                "SELECT b.a_id, b.id "
+                "FROM b WHERE b.a_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
+                {"primary_keys": list(range(4, 6))},
+            ),
+            CompiledSQL(
+                "SELECT c.b_id, c.id "
+                "FROM c WHERE c.b_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY c.id",
+                {"primary_keys": b_list[0:4]},
+            ),
+            CompiledSQL(
+                "SELECT c.b_id, c.id "
+                "FROM c WHERE c.b_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY c.id",
+                {"primary_keys": b_list[4:8]},
+            ),
+            CompiledSQL(
+                "SELECT c.b_id, c.id "
+                "FROM c WHERE c.b_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY c.id",
+                {"primary_keys": b_list[8:12]},
+            ),
+            CompiledSQL(
+                "SELECT c.b_id, c.id "
+                "FROM c WHERE c.b_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY c.id",
+                {"primary_keys": b_list[12:]},
+            ),
+        )
+
+    def test_chained_selectinload_with_one_chunksize(self):
+        """
+        This test is to make sure that a previous custom chunksize doesn't
+        effect chunksize in remaining selectinload
+        """
+
+        A, B, C = self.classes("A", "B", "C")
+
+        b_list = [7, 8, 9, 13, 14, 15, 19, 20, 21, 25, 26, 27, 31, 32, 33]
+
+        session = fixture_session()
+
+        def go():
+            statement = (
+                select(A)
+                .options(selectinload(A.bs, chunksize=3).selectinload(B.cs))
+                .order_by(A.id)
+            )
+
+            session.scalars(statement).all()
+
+        self.assert_sql_execution(
+            testing.db,
+            go,
+            CompiledSQL("SELECT a.id FROM a ORDER BY a.id", {}),
+            CompiledSQL(
+                "SELECT b.a_id, b.id "
+                "FROM b WHERE b.a_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
+                {"primary_keys": list(range(1, 4))},
+            ),
+            CompiledSQL(
+                "SELECT b.a_id, b.id "
+                "FROM b WHERE b.a_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY b.id",
+                {"primary_keys": list(range(4, 6))},
+            ),
+            CompiledSQL(
+                "SELECT c.b_id, c.id "
+                "FROM c WHERE c.b_id IN "
+                "(__[POSTCOMPILE_primary_keys]) ORDER BY c.id",
+                {"primary_keys": b_list},
+            ),
+        )
+
+
 class SubRelationFromJoinedSubclassMultiLevelTest(_Polymorphic):
     @classmethod
     def define_tables(cls, metadata):