]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Block Result.unique() with Result.yield_per() for ORM results
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 May 2026 18:30:37 +0000 (14:30 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 May 2026 22:47:56 +0000 (18:47 -0400)
The unique() + yield_per combination was only blocked when yield_per
was set via execution_options(yield_per=N); calling these as methods
on the result (e.g. result.unique().yield_per(N)) bypassed the check
and silently produced incorrect results.

Restructured _unique_filters on SimpleResultMetaData to be a callable
_create_unique_filters that receives the Result, allowing it to check
the yield_per state regardless of how it was activated.

Fixes: #13293
Change-Id: I7e6a5e5b2e1d4c8f9a0b3d6e7f1c2a4d5b8e9f0a

doc/build/changelog/unreleased_21/13293.rst [new file with mode: 0644]
lib/sqlalchemy/engine/_result_cy.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/orm/loading.py
test/aaa_profiling/test_memusage.py
test/base/test_result.py
test/orm/test_query.py

diff --git a/doc/build/changelog/unreleased_21/13293.rst b/doc/build/changelog/unreleased_21/13293.rst
new file mode 100644 (file)
index 0000000..0819277
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 13293
+
+    Fixed issue where the :meth:`_engine.Result.unique` filter was not properly
+    validated against the :meth:`_engine.Result.yield_per` method when both
+    were called as methods on the result object, such as
+    ``result.unique().yield_per(N)`` or ``result.yield_per(N).unique()``. The
+    uniquing filter was previously only checked when ``yield_per`` was set via
+    :paramref:`_engine.Connection.execution_options.yield_per`. Since these two
+    features are fundamentally incompatible for ORM results, an
+    :class:`.InvalidRequestError` is now raised in all cases.
index f99e351074ee10822861bb763fb9bb1ff225b451..51c56020024ca42f741007a62545404379bb2165 100644 (file)
@@ -13,6 +13,7 @@ from collections.abc import Sequence
 from enum import Enum
 import operator
 from typing import Any
+from typing import cast
 from typing import Generic
 from typing import Literal
 from typing import overload
@@ -557,17 +558,21 @@ class BaseResultInternal(Generic[_R]):
         assert self._unique_filter_state is not None
         uniques, strategy = self._unique_filter_state
 
-        if strategy is None and self._metadata._unique_filters is not None:
-            real_result = (
-                self if self._real_result is None else self._real_result
+        if (
+            strategy is None
+            and self._metadata._create_unique_filters is not None
+        ):
+            real_result = cast(
+                "Result[Any]",
+                self if self._real_result is None else self._real_result,
             )
+            filters = self._metadata._create_unique_filters(real_result)
             if (
                 real_result._source_supports_scalars
                 and not self._generate_rows
             ):
-                strategy = self._metadata._unique_filters[0]
+                strategy = filters[0]
             else:
-                filters = self._metadata._unique_filters
                 if self._metadata._tuplefilter is not None:
                     filters = self._metadata._tuplefilter(filters)
 
index 1fd66aa9ee7a060fa3032ecc9d73fb23edb21e38..ceb7dc1769fe279481045ae885c8fa2daec0722b 100644 (file)
@@ -166,7 +166,7 @@ class CursorResultMetaData(ResultMetaData):
         "_safe_for_cache",
         "_unpickled",
         "_key_to_index",
-        # don't need _unique_filters support here for now.  Can be added
+        # don't need _create_unique_filters here for now.  Can be added
         # if a need arises.
     )
 
index 471e2e4c65e31f625bd6790f3b8177b58cd3d333..ff09c98c72e8094ce581cd8aa760b6352d286e81 100644 (file)
@@ -94,7 +94,9 @@ class ResultMetaData:
 
     _tuplefilter: Optional[_TupleGetterType] = None
     _translated_indexes: Optional[Sequence[int]] = None
-    _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None
+    _create_unique_filters: Optional[
+        Callable[["Result[Any]"], Sequence[Optional[Callable[[Any], Any]]]]
+    ] = None
     _keymap: _KeyMapType
     _keys: Sequence[str]
     _processors: Optional[_ProcessorsType]
@@ -250,7 +252,7 @@ class SimpleResultMetaData(ResultMetaData):
         "_processors",
         "_tuplefilter",
         "_translated_indexes",
-        "_unique_filters",
+        "_create_unique_filters",
         "_key_to_index",
     )
 
@@ -263,12 +265,17 @@ class SimpleResultMetaData(ResultMetaData):
         _processors: Optional[_ProcessorsType] = None,
         _tuplefilter: Optional[_TupleGetterType] = None,
         _translated_indexes: Optional[Sequence[int]] = None,
-        _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None,
+        _create_unique_filters: Optional[
+            Callable[
+                [Any],
+                Sequence[Optional[Callable[[Any], Any]]],
+            ]
+        ] = None,
     ):
         self._keys = list(keys)
         self._tuplefilter = _tuplefilter
         self._translated_indexes = _translated_indexes
-        self._unique_filters = _unique_filters
+        self._create_unique_filters = _create_unique_filters
         if extra:
             assert len(self._keys) == len(extra)
             recs_names = [
@@ -294,16 +301,24 @@ class SimpleResultMetaData(ResultMetaData):
         return key in self._keymap
 
     def _for_freeze(self) -> ResultMetaData:
-        unique_filters = self._unique_filters
-        if unique_filters and self._tuplefilter:
-            unique_filters = self._tuplefilter(unique_filters)
-
         # TODO: are we freezing the result with or without uniqueness
         # applied?
+        create_unique_filters = self._create_unique_filters
+        if create_unique_filters is not None and self._tuplefilter is not None:
+            _tuplefilter = self._tuplefilter
+            _orig_create_unique_filters = create_unique_filters
+
+            def create_unique_filters_filtered(
+                result: Result[Any],
+            ) -> Sequence[Optional[Callable[[Any], Any]]]:
+                return _tuplefilter(_orig_create_unique_filters(result))
+
+            create_unique_filters = create_unique_filters_filtered
+
         return SimpleResultMetaData(
             self._keys,
             extra=[self._keymap[key][2] for key in self._keys],
-            _unique_filters=unique_filters,
+            _create_unique_filters=create_unique_filters,
         )
 
     def __getstate__(self) -> Dict[str, Any]:
@@ -376,7 +391,7 @@ class SimpleResultMetaData(ResultMetaData):
             _tuplefilter=tup,
             _translated_indexes=indexes,
             _processors=self._processors,
-            _unique_filters=self._unique_filters,
+            _create_unique_filters=self._create_unique_filters,
         )
 
         return new_metadata
index 73b5bce29b3ee288e87aa922acccbaf6dab255fd..2b1ff19acb150ca4f61feb09b7bf818206953326 100644 (file)
@@ -136,11 +136,6 @@ def instances(
         with util.safe_reraise():
             cursor.close()
 
-    def _no_unique(entry):
-        raise sa_exc.InvalidRequestError(
-            "Can't use the ORM yield_per feature in conjunction with unique()"
-        )
-
     def _not_hashable(datatype, *, legacy=False, uncertain=False):
         if not legacy:
 
@@ -184,11 +179,20 @@ def instances(
 
             return go
 
-    unique_filters = [
-        (
-            _no_unique
-            if context.yield_per
-            else (
+    _uniquing_is_active = False
+
+    def _create_unique_filters(result):
+        nonlocal _uniquing_is_active
+
+        if result._yield_per:
+            raise sa_exc.InvalidRequestError(
+                "Can't use the ORM yield_per feature "
+                "in conjunction with unique()"
+            )
+
+        _uniquing_is_active = True
+        return [
+            (
                 _not_hashable(
                     ent.column.type,  # type: ignore
                     legacy=context.load_options._legacy_uniquing,
@@ -200,12 +204,11 @@ def instances(
                 )
                 else id if ent.use_id_for_hash else None
             )
-        )
-        for ent in context.compile_state._entities
-    ]
+            for ent in context.compile_state._entities
+        ]
 
     row_metadata = SimpleResultMetaData(
-        labels, extra, _unique_filters=unique_filters
+        labels, extra, _create_unique_filters=_create_unique_filters
     )
 
     def chunks(size):  # type: ignore
@@ -215,6 +218,11 @@ def instances(
             context.partials = {}
 
             if yield_per:
+                if _uniquing_is_active:
+                    raise sa_exc.InvalidRequestError(
+                        "Can't use the ORM yield_per feature "
+                        "in conjunction with unique()"
+                    )
                 fetch = cursor.fetchmany(yield_per)
 
                 if not fetch:
index 6223b86821121c4dfe67b95e28bc87dbb138c0e2..5b7bbbdc99e9e4b7dc2908998e598e8ba4bb93f6 100644 (file)
@@ -1507,7 +1507,7 @@ class CycleTest(_fixtures.FixtureTest):
 
         stmt = s.query(User).join(User.addresses).statement
 
-        @assert_cycles(8)
+        @assert_cycles(20)
         def go():
             result = s.execute(stmt)
             rows = result.fetchall()  # noqa
@@ -1522,7 +1522,7 @@ class CycleTest(_fixtures.FixtureTest):
 
         stmt = s.query(User).join(User.addresses).statement
 
-        @assert_cycles(8)
+        @assert_cycles(20)
         def go():
             result = s.execute(stmt)
             for partition in result.partitions(3):
@@ -1538,7 +1538,7 @@ class CycleTest(_fixtures.FixtureTest):
 
         stmt = s.query(User).join(User.addresses).statement
 
-        @assert_cycles(8)
+        @assert_cycles(20)
         def go():
             result = s.execute(stmt)
             for partition in result.unique().partitions(3):
index 57970c740b73daf8ad642de447a26434c9f5e432..dcf85596333dcc7c3b3d94a07825753aafaf87c7 100644 (file)
@@ -308,7 +308,9 @@ class ResultTest(fixtures.TestBase):
             iter(data),
         )
         if default_filters:
-            res._metadata._unique_filters = default_filters
+            res._metadata._create_unique_filters = (
+                lambda result: default_filters
+            )
 
         if alt_row:
             res._process_row = alt_row
@@ -958,6 +960,16 @@ class ResultTest(fixtures.TestBase):
         r1 = frozen()
         eq_(r1.fetchall(), [(1, 1), (1, 2), (3, 2)])
 
+    def test_columns_unique_freeze_w_unique_filters(self):
+        result = self._fixture(default_filters=[id, None, None])
+
+        result = result.columns("b", "c")
+
+        frozen = result.freeze()
+
+        r1 = frozen().unique()
+        eq_(r1.fetchall(), [(1, 1), (1, 2), (3, 2)])
+
     def test_columns_freeze(self):
         result = self._fixture()
 
@@ -1206,7 +1218,7 @@ class OnlyScalarsTest(fixtures.TestBase):
 
     def test_scalar_mode_mfiltered_unique_rows_all(self, no_tuple_fixture):
         metadata = result.SimpleResultMetaData(
-            ["a", "b", "c"], _unique_filters=[int]
+            ["a", "b", "c"], _create_unique_filters=lambda result: [int]
         )
 
         r = result.ChunkedIteratorResult(
@@ -1227,7 +1239,7 @@ class OnlyScalarsTest(fixtures.TestBase):
     )
     def test_unique_scalar_accessors(self, no_tuple_one_fixture, get):
         metadata = result.SimpleResultMetaData(
-            ["a", "b", "c"], _unique_filters=[int]
+            ["a", "b", "c"], _create_unique_filters=lambda result: [int]
         )
 
         r = result.ChunkedIteratorResult(
@@ -1242,7 +1254,7 @@ class OnlyScalarsTest(fixtures.TestBase):
 
     def test_scalar_mode_mfiltered_unique_mappings_all(self, no_tuple_fixture):
         metadata = result.SimpleResultMetaData(
-            ["a", "b", "c"], _unique_filters=[int]
+            ["a", "b", "c"], _create_unique_filters=lambda result: [int]
         )
 
         r = result.ChunkedIteratorResult(
@@ -1257,7 +1269,7 @@ class OnlyScalarsTest(fixtures.TestBase):
 
     def test_scalar_mode_mfiltered_unique_scalars_all(self, no_tuple_fixture):
         metadata = result.SimpleResultMetaData(
-            ["a", "b", "c"], _unique_filters=[int]
+            ["a", "b", "c"], _create_unique_filters=lambda result: [int]
         )
 
         r = result.ChunkedIteratorResult(
index 0dff01612af471835a3f7a850883fa3c47e047fe..547750427a6fd0eff867c2ff6c4dbe79f11eab6e 100644 (file)
@@ -5671,6 +5671,65 @@ class YieldTest(_fixtures.FixtureTest):
         ):
             next(result)
 
+        result.close()
+
+    def test_no_unique_w_yield_per_method_unique_first(self):
+        self._eagerload_mappings()
+
+        User = self.classes.User
+
+        sess = fixture_session()
+        stmt = select(User)
+
+        result = sess.execute(stmt).unique().yield_per(10)
+
+        with expect_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't use the ORM yield_per feature in "
+            r"conjunction with unique\(\)",
+        ):
+            next(result)
+
+        result.close()
+
+    def test_no_unique_w_yield_per_method_yield_per_first(self):
+        self._eagerload_mappings()
+
+        User = self.classes.User
+
+        sess = fixture_session()
+        stmt = select(User)
+
+        result = sess.execute(stmt).yield_per(10).unique()
+
+        with expect_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't use the ORM yield_per feature in "
+            r"conjunction with unique\(\)",
+        ):
+            next(result)
+
+        result.close()
+
+    def test_no_unique_w_yield_per_stream_results(self):
+        self._eagerload_mappings()
+
+        User = self.classes.User
+
+        sess = fixture_session()
+        stmt = select(User).execution_options(stream_results=True)
+
+        result = sess.execute(stmt).unique().yield_per(10)
+
+        with expect_raises_message(
+            sa_exc.InvalidRequestError,
+            r"Can't use the ORM yield_per feature in "
+            r"conjunction with unique\(\)",
+        ):
+            next(result)
+
+        result.close()
+
 
 class YieldIterationTest(_fixtures.FixtureTest):
     run_inserts = "once"