From: Mike Bayer Date: Tue, 12 May 2026 18:30:37 +0000 (-0400) Subject: Block Result.unique() with Result.yield_per() for ORM results X-Git-Url: http://git.ipfire.org/gitweb/index.cgi?a=commitdiff_plain;h=15a9df9b2db9603ecaa8635396596b4ecca7d481;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Block Result.unique() with Result.yield_per() for ORM results The unique() + yield_per combination was only blocked when yield_per was set via execution_options(yield_per=N); calling these as methods on the result (e.g. result.unique().yield_per(N)) bypassed the check and silently produced incorrect results. Restructured _unique_filters on SimpleResultMetaData to be a callable _create_unique_filters that receives the Result, allowing it to check the yield_per state regardless of how it was activated. Fixes: #13293 Change-Id: I7e6a5e5b2e1d4c8f9a0b3d6e7f1c2a4d5b8e9f0a --- diff --git a/doc/build/changelog/unreleased_21/13293.rst b/doc/build/changelog/unreleased_21/13293.rst new file mode 100644 index 0000000000..0819277032 --- /dev/null +++ b/doc/build/changelog/unreleased_21/13293.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 13293 + + Fixed issue where the :meth:`_engine.Result.unique` filter was not properly + validated against the :meth:`_engine.Result.yield_per` method when both + were called as methods on the result object, such as + ``result.unique().yield_per(N)`` or ``result.yield_per(N).unique()``. The + uniquing filter was previously only checked when ``yield_per`` was set via + :paramref:`_engine.Connection.execution_options.yield_per`. Since these two + features are fundamentally incompatible for ORM results, an + :class:`.InvalidRequestError` is now raised in all cases. diff --git a/lib/sqlalchemy/engine/_result_cy.py b/lib/sqlalchemy/engine/_result_cy.py index f99e351074..51c5602002 100644 --- a/lib/sqlalchemy/engine/_result_cy.py +++ b/lib/sqlalchemy/engine/_result_cy.py @@ -13,6 +13,7 @@ from collections.abc import Sequence from enum import Enum import operator from typing import Any +from typing import cast from typing import Generic from typing import Literal from typing import overload @@ -557,17 +558,21 @@ class BaseResultInternal(Generic[_R]): assert self._unique_filter_state is not None uniques, strategy = self._unique_filter_state - if strategy is None and self._metadata._unique_filters is not None: - real_result = ( - self if self._real_result is None else self._real_result + if ( + strategy is None + and self._metadata._create_unique_filters is not None + ): + real_result = cast( + "Result[Any]", + self if self._real_result is None else self._real_result, ) + filters = self._metadata._create_unique_filters(real_result) if ( real_result._source_supports_scalars and not self._generate_rows ): - strategy = self._metadata._unique_filters[0] + strategy = filters[0] else: - filters = self._metadata._unique_filters if self._metadata._tuplefilter is not None: filters = self._metadata._tuplefilter(filters) diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 1fd66aa9ee..ceb7dc1769 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -166,7 +166,7 @@ class CursorResultMetaData(ResultMetaData): "_safe_for_cache", "_unpickled", "_key_to_index", - # don't need _unique_filters support here for now. Can be added + # don't need _create_unique_filters here for now. Can be added # if a need arises. ) diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 471e2e4c65..ff09c98c72 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -94,7 +94,9 @@ class ResultMetaData: _tuplefilter: Optional[_TupleGetterType] = None _translated_indexes: Optional[Sequence[int]] = None - _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None + _create_unique_filters: Optional[ + Callable[["Result[Any]"], Sequence[Optional[Callable[[Any], Any]]]] + ] = None _keymap: _KeyMapType _keys: Sequence[str] _processors: Optional[_ProcessorsType] @@ -250,7 +252,7 @@ class SimpleResultMetaData(ResultMetaData): "_processors", "_tuplefilter", "_translated_indexes", - "_unique_filters", + "_create_unique_filters", "_key_to_index", ) @@ -263,12 +265,17 @@ class SimpleResultMetaData(ResultMetaData): _processors: Optional[_ProcessorsType] = None, _tuplefilter: Optional[_TupleGetterType] = None, _translated_indexes: Optional[Sequence[int]] = None, - _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None, + _create_unique_filters: Optional[ + Callable[ + [Any], + Sequence[Optional[Callable[[Any], Any]]], + ] + ] = None, ): self._keys = list(keys) self._tuplefilter = _tuplefilter self._translated_indexes = _translated_indexes - self._unique_filters = _unique_filters + self._create_unique_filters = _create_unique_filters if extra: assert len(self._keys) == len(extra) recs_names = [ @@ -294,16 +301,24 @@ class SimpleResultMetaData(ResultMetaData): return key in self._keymap def _for_freeze(self) -> ResultMetaData: - unique_filters = self._unique_filters - if unique_filters and self._tuplefilter: - unique_filters = self._tuplefilter(unique_filters) - # TODO: are we freezing the result with or without uniqueness # applied? + create_unique_filters = self._create_unique_filters + if create_unique_filters is not None and self._tuplefilter is not None: + _tuplefilter = self._tuplefilter + _orig_create_unique_filters = create_unique_filters + + def create_unique_filters_filtered( + result: Result[Any], + ) -> Sequence[Optional[Callable[[Any], Any]]]: + return _tuplefilter(_orig_create_unique_filters(result)) + + create_unique_filters = create_unique_filters_filtered + return SimpleResultMetaData( self._keys, extra=[self._keymap[key][2] for key in self._keys], - _unique_filters=unique_filters, + _create_unique_filters=create_unique_filters, ) def __getstate__(self) -> Dict[str, Any]: @@ -376,7 +391,7 @@ class SimpleResultMetaData(ResultMetaData): _tuplefilter=tup, _translated_indexes=indexes, _processors=self._processors, - _unique_filters=self._unique_filters, + _create_unique_filters=self._create_unique_filters, ) return new_metadata diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 73b5bce29b..2b1ff19acb 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -136,11 +136,6 @@ def instances( with util.safe_reraise(): cursor.close() - def _no_unique(entry): - raise sa_exc.InvalidRequestError( - "Can't use the ORM yield_per feature in conjunction with unique()" - ) - def _not_hashable(datatype, *, legacy=False, uncertain=False): if not legacy: @@ -184,11 +179,20 @@ def instances( return go - unique_filters = [ - ( - _no_unique - if context.yield_per - else ( + _uniquing_is_active = False + + def _create_unique_filters(result): + nonlocal _uniquing_is_active + + if result._yield_per: + raise sa_exc.InvalidRequestError( + "Can't use the ORM yield_per feature " + "in conjunction with unique()" + ) + + _uniquing_is_active = True + return [ + ( _not_hashable( ent.column.type, # type: ignore legacy=context.load_options._legacy_uniquing, @@ -200,12 +204,11 @@ def instances( ) else id if ent.use_id_for_hash else None ) - ) - for ent in context.compile_state._entities - ] + for ent in context.compile_state._entities + ] row_metadata = SimpleResultMetaData( - labels, extra, _unique_filters=unique_filters + labels, extra, _create_unique_filters=_create_unique_filters ) def chunks(size): # type: ignore @@ -215,6 +218,11 @@ def instances( context.partials = {} if yield_per: + if _uniquing_is_active: + raise sa_exc.InvalidRequestError( + "Can't use the ORM yield_per feature " + "in conjunction with unique()" + ) fetch = cursor.fetchmany(yield_per) if not fetch: diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index 6223b86821..5b7bbbdc99 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -1507,7 +1507,7 @@ class CycleTest(_fixtures.FixtureTest): stmt = s.query(User).join(User.addresses).statement - @assert_cycles(8) + @assert_cycles(20) def go(): result = s.execute(stmt) rows = result.fetchall() # noqa @@ -1522,7 +1522,7 @@ class CycleTest(_fixtures.FixtureTest): stmt = s.query(User).join(User.addresses).statement - @assert_cycles(8) + @assert_cycles(20) def go(): result = s.execute(stmt) for partition in result.partitions(3): @@ -1538,7 +1538,7 @@ class CycleTest(_fixtures.FixtureTest): stmt = s.query(User).join(User.addresses).statement - @assert_cycles(8) + @assert_cycles(20) def go(): result = s.execute(stmt) for partition in result.unique().partitions(3): diff --git a/test/base/test_result.py b/test/base/test_result.py index 57970c740b..dcf8559633 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -308,7 +308,9 @@ class ResultTest(fixtures.TestBase): iter(data), ) if default_filters: - res._metadata._unique_filters = default_filters + res._metadata._create_unique_filters = ( + lambda result: default_filters + ) if alt_row: res._process_row = alt_row @@ -958,6 +960,16 @@ class ResultTest(fixtures.TestBase): r1 = frozen() eq_(r1.fetchall(), [(1, 1), (1, 2), (3, 2)]) + def test_columns_unique_freeze_w_unique_filters(self): + result = self._fixture(default_filters=[id, None, None]) + + result = result.columns("b", "c") + + frozen = result.freeze() + + r1 = frozen().unique() + eq_(r1.fetchall(), [(1, 1), (1, 2), (3, 2)]) + def test_columns_freeze(self): result = self._fixture() @@ -1206,7 +1218,7 @@ class OnlyScalarsTest(fixtures.TestBase): def test_scalar_mode_mfiltered_unique_rows_all(self, no_tuple_fixture): metadata = result.SimpleResultMetaData( - ["a", "b", "c"], _unique_filters=[int] + ["a", "b", "c"], _create_unique_filters=lambda result: [int] ) r = result.ChunkedIteratorResult( @@ -1227,7 +1239,7 @@ class OnlyScalarsTest(fixtures.TestBase): ) def test_unique_scalar_accessors(self, no_tuple_one_fixture, get): metadata = result.SimpleResultMetaData( - ["a", "b", "c"], _unique_filters=[int] + ["a", "b", "c"], _create_unique_filters=lambda result: [int] ) r = result.ChunkedIteratorResult( @@ -1242,7 +1254,7 @@ class OnlyScalarsTest(fixtures.TestBase): def test_scalar_mode_mfiltered_unique_mappings_all(self, no_tuple_fixture): metadata = result.SimpleResultMetaData( - ["a", "b", "c"], _unique_filters=[int] + ["a", "b", "c"], _create_unique_filters=lambda result: [int] ) r = result.ChunkedIteratorResult( @@ -1257,7 +1269,7 @@ class OnlyScalarsTest(fixtures.TestBase): def test_scalar_mode_mfiltered_unique_scalars_all(self, no_tuple_fixture): metadata = result.SimpleResultMetaData( - ["a", "b", "c"], _unique_filters=[int] + ["a", "b", "c"], _create_unique_filters=lambda result: [int] ) r = result.ChunkedIteratorResult( diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 0dff01612a..547750427a 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -5671,6 +5671,65 @@ class YieldTest(_fixtures.FixtureTest): ): next(result) + result.close() + + def test_no_unique_w_yield_per_method_unique_first(self): + self._eagerload_mappings() + + User = self.classes.User + + sess = fixture_session() + stmt = select(User) + + result = sess.execute(stmt).unique().yield_per(10) + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Can't use the ORM yield_per feature in " + r"conjunction with unique\(\)", + ): + next(result) + + result.close() + + def test_no_unique_w_yield_per_method_yield_per_first(self): + self._eagerload_mappings() + + User = self.classes.User + + sess = fixture_session() + stmt = select(User) + + result = sess.execute(stmt).yield_per(10).unique() + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Can't use the ORM yield_per feature in " + r"conjunction with unique\(\)", + ): + next(result) + + result.close() + + def test_no_unique_w_yield_per_stream_results(self): + self._eagerload_mappings() + + User = self.classes.User + + sess = fixture_session() + stmt = select(User).execution_options(stream_results=True) + + result = sess.execute(stmt).unique().yield_per(10) + + with expect_raises_message( + sa_exc.InvalidRequestError, + r"Can't use the ORM yield_per feature in " + r"conjunction with unique\(\)", + ): + next(result) + + result.close() + class YieldIterationTest(_fixtures.FixtureTest): run_inserts = "once"