From: Federico Caselli
Date: Fri, 29 Dec 2023 21:30:24 +0000 (+0100)
Subject: Improve handling of sentinel columns
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=06204730851785ffcb230774d27e5ba555186580;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

Improve handling of sentinel columns

Fixed issue in "insertmanyvalues" feature where an INSERT..RETURNING
that also made use of a sentinel column to track results would fail to
filter out the additional column when :meth:`.Result.unique` was used
to uniquify the result set.

Fixes: #10802
Change-Id: Ie4f9dab96193099002088c5219cc41a543a00f62
---
diff --git a/doc/build/changelog/unreleased_21/10802.rst b/doc/build/changelog/unreleased_21/10802.rst
new file mode 100644
index 0000000000..cb84386515
--- /dev/null
+++ b/doc/build/changelog/unreleased_21/10802.rst
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 10802
+
+    Fixed issue in "insertmanyvalues" feature where an INSERT..RETURNING
+    that also made use of a sentinel column to track results would fail to
+    filter out the additional column when :meth:`.Result.unique` was used
+    to uniquify the result set.
diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index 351ccda4c3..165ae2feaa 100644
--- a/lib/sqlalchemy/engine/cursor.py
+++ b/lib/sqlalchemy/engine/cursor.py
@@ -200,11 +200,14 @@ class CursorResultMetaData(ResultMetaData):
         new_obj._key_to_index = self._make_key_to_index(keymap, MD_INDEX)
         return new_obj
 
-    def _remove_processors(self) -> Self:
-        assert not self._tuplefilter
+    def _remove_processors_and_tuple_filter(self) -> Self:
+        if self._tuplefilter:
+            proc = self._tuplefilter(self._processors)
+        else:
+            proc = self._processors
         return self._make_new_metadata(
             unpickled=self._unpickled,
-            processors=[None] * len(self._processors),
+            processors=[None] * len(proc),
             tuplefilter=None,
             translated_indexes=None,
             keymap={
@@ -217,8 +220,6 @@ class CursorResultMetaData(ResultMetaData):
         )
 
     def _splice_horizontally(self, other: CursorResultMetaData) -> Self:
-        assert not self._tuplefilter
-
         keymap = dict(self._keymap)
         offset = len(self._keys)
         keymap.update(
@@ -236,12 +237,25 @@ class CursorResultMetaData(ResultMetaData):
                 for key, value in other._keymap.items()
             }
         )
+        self_tf = self._tuplefilter
+        other_tf = other._tuplefilter
+
+        proc: List[Any] = []
+        for pp, tf in [
+            (self._processors, self_tf),
+            (other._processors, other_tf),
+        ]:
+            proc.extend(pp if tf is None else tf(pp))
+
+        new_keys = [*self._keys, *other._keys]
+        assert len(proc) == len(new_keys)
+
         return self._make_new_metadata(
             unpickled=self._unpickled,
-            processors=self._processors + other._processors,  # type: ignore
+            processors=proc,
             tuplefilter=None,
             translated_indexes=None,
-            keys=self._keys + other._keys,  # type: ignore
+            keys=new_keys,
             keymap=keymap,
             safe_for_cache=self._safe_for_cache,
             keymap_by_result_column_idx={
@@ -323,7 +337,6 @@ class CursorResultMetaData(ResultMetaData):
             for metadata_entry in self._keymap.values()
         }
 
-        assert not self._tuplefilter
         return self._make_new_metadata(
             keymap=self._keymap
             | {
@@ -335,7 +348,7 @@ class CursorResultMetaData(ResultMetaData):
             },
             unpickled=self._unpickled,
             processors=self._processors,
-            tuplefilter=None,
+            tuplefilter=self._tuplefilter,
             translated_indexes=None,
             keys=self._keys,
             safe_for_cache=self._safe_for_cache,
@@ -348,9 +361,17 @@ class CursorResultMetaData(ResultMetaData):
         cursor_description: _DBAPICursorDescription,
         *,
         driver_column_names: bool = False,
+        num_sentinel_cols: int = 0,
     ):
         context = parent.context
-        self._tuplefilter = None
+        if num_sentinel_cols > 0:
+            # this is slightly faster than letting tuplegetter use the indexes
+            self._tuplefilter = tuplefilter = operator.itemgetter(
+                slice(-num_sentinel_cols)
+            )
+            cursor_description = tuplefilter(cursor_description)
+        else:
+            self._tuplefilter = tuplefilter = None
         self._translated_indexes = None
         self._safe_for_cache = self._unpickled = False
 
@@ -362,6 +383,8 @@ class CursorResultMetaData(ResultMetaData):
                 ad_hoc_textual,
                 loose_column_name_matching,
             ) = context.result_column_struct
+            if tuplefilter is not None:
+                result_columns = tuplefilter(result_columns)
             num_ctx_cols = len(result_columns)
         else:
             result_columns = cols_are_ordered = (  # type: ignore
@@ -389,6 +412,10 @@ class CursorResultMetaData(ResultMetaData):
         self._processors = [
             metadata_entry[MD_PROCESSOR] for metadata_entry in raw
         ]
+        if num_sentinel_cols > 0:
+            # add None processors for the sentinel columns, since the
+            # processors are passed to the tuplefilter before being used
+            self._processors.extend([None] * num_sentinel_cols)
 
         # this is used when using this ResultMetaData in a Core-only cache
         # retrieval context. it's initialized on first cache retrieval
@@ -951,7 +978,7 @@ class CursorResultMetaData(ResultMetaData):
         self, keys: Sequence[Any]
     ) -> Iterator[_NonAmbigCursorKeyMapRecType]:
         for key in keys:
-            if int in key.__class__.__mro__:
+            if isinstance(key, int):
                 key = self._keys[key]
 
             try:
@@ -995,10 +1022,11 @@ class CursorResultMetaData(ResultMetaData):
         self._keys = state["_keys"]
         self._unpickled = True
         if state["_translated_indexes"]:
-            self._translated_indexes = cast(
-                "List[int]", state["_translated_indexes"]
-            )
-            self._tuplefilter = tuplegetter(*self._translated_indexes)
+            translated_indexes: List[Any]
+            self._translated_indexes = translated_indexes = state[
+                "_translated_indexes"
+            ]
+            self._tuplefilter = tuplegetter(*translated_indexes)
         else:
             self._translated_indexes = self._tuplefilter = None
 
@@ -1537,20 +1565,19 @@ class CursorResult(Result[Unpack[_Ts]]):
             metadata = self._init_metadata(context, cursor_description)
             _make_row: Any
+            proc = metadata._effective_processors
+            tf = metadata._tuplefilter
             _make_row = functools.partial(
                 Row,
                 metadata,
-                metadata._effective_processors,
+                proc if tf is None or proc is None else tf(proc),
                 metadata._key_to_index,
             )
-
-            if context._num_sentinel_cols:
-                sentinel_filter = operator.itemgetter(
-                    slice(-context._num_sentinel_cols)
-                )
+            if tf is not None:
+                _fixed_tf = tf  # needed to make mypy happy...
 
                 def _sliced_row(raw_data):
-                    return _make_row(sentinel_filter(raw_data))
+                    return _make_row(_fixed_tf(raw_data))
 
                 sliced_row = _sliced_row
             else:
                 sliced_row = _make_row
@@ -1577,7 +1604,11 @@ class CursorResult(Result[Unpack[_Ts]]):
             assert context._num_sentinel_cols == 0
             self._metadata = self._no_result_metadata
 
-    def _init_metadata(self, context, cursor_description):
+    def _init_metadata(
+        self,
+        context: DefaultExecutionContext,
+        cursor_description: _DBAPICursorDescription,
+    ) -> CursorResultMetaData:
         driver_column_names = context.execution_options.get(
             "driver_column_names", False
         )
@@ -1587,14 +1618,25 @@ class CursorResult(Result[Unpack[_Ts]]):
         metadata: CursorResultMetaData
 
         if driver_column_names:
+            # TODO: test this case
             metadata = CursorResultMetaData(
-                self, cursor_description, driver_column_names=True
+                self,
+                cursor_description,
+                driver_column_names=True,
+                num_sentinel_cols=context._num_sentinel_cols,
             )
             assert not metadata._safe_for_cache
         elif compiled._cached_metadata:
             metadata = compiled._cached_metadata
         else:
-            metadata = CursorResultMetaData(self, cursor_description)
+            metadata = CursorResultMetaData(
+                self,
+                cursor_description,
+                # the number of sentinel columns is stored on the context
+                # but it's a characteristic of the compiled object
+                # so it's ok to apply it to a cacheable metadata.
+                num_sentinel_cols=context._num_sentinel_cols,
+            )
 
             if metadata._safe_for_cache:
                 compiled._cached_metadata = metadata
@@ -1618,7 +1660,7 @@ class CursorResult(Result[Unpack[_Ts]]):
             )
             and compiled._result_columns
             and context.cache_hit is context.dialect.CACHE_HIT
-            and compiled.statement is not context.invoked_statement
+            and compiled.statement is not context.invoked_statement  # type: ignore[comparison-overlap] # noqa: E501
         ):
             metadata = metadata._adapt_to_context(context)
 
@@ -1838,7 +1880,9 @@ class CursorResult(Result[Unpack[_Ts]]):
         """
         return self.context.returned_default_rows
 
-    def splice_horizontally(self, other):
+    def splice_horizontally(
+        self, other: CursorResult[Any]
+    ) -> CursorResult[Any]:
         """Return a new :class:`.CursorResult` that "horizontally splices"
         together the rows of this :class:`.CursorResult` with that of another
         :class:`.CursorResult`.
@@ -1893,17 +1937,23 @@
 
         """  # noqa: E501
 
-        clone = self._generate()
+        clone: CursorResult[Any] = self._generate()
+        assert clone is self  # just to note
+        assert isinstance(other._metadata, CursorResultMetaData)
+        assert isinstance(self._metadata, CursorResultMetaData)
+        self_tf = self._metadata._tuplefilter
+        other_tf = other._metadata._tuplefilter
+        clone._metadata = self._metadata._splice_horizontally(other._metadata)
+
         total_rows = [
-            tuple(r1) + tuple(r2)
+            tuple(r1 if self_tf is None else self_tf(r1))
+            + tuple(r2 if other_tf is None else other_tf(r2))
            for r1, r2 in zip(
                 list(self._raw_row_iterator()),
                 list(other._raw_row_iterator()),
             )
         ]
 
-        clone._metadata = clone._metadata._splice_horizontally(other._metadata)
-
         clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
             None,
             initial_buffer=total_rows,
@@ -1951,6 +2001,9 @@
         :meth:`.Insert.return_defaults` along with the
         "supplemental columns" feature.
 
+        NOTE: this method has no effect when a unique filter is applied
+        to the result, meaning that no rows will be returned.
+
         """
 
         if self._echo:
@@ -1963,7 +2016,7 @@
         # rows
         self._metadata = cast(
             CursorResultMetaData, self._metadata
-        )._remove_processors()
+        )._remove_processors_and_tuple_filter()
 
         self.cursor_strategy = FullyBufferedCursorFetchStrategy(
             None,
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 4eb45c1d59..c8bdb56635 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -1956,11 +1956,8 @@ class DefaultExecutionContext(ExecutionContext):
             strategy = _cursor._NO_CURSOR_DML
         elif self._num_sentinel_cols:
             assert self.execute_style is ExecuteStyle.INSERTMANYVALUES
-            # strip out the sentinel columns from cursor description
-            # a similar logic is done to the rows only in CursorResult
-            cursor_description = cursor_description[
-                0 : -self._num_sentinel_cols
-            ]
+            # the sentinel columns are handled in CursorResult._init_metadata
+            # using essentially the same logic as _reduce
 
         result: _cursor.CursorResult[Any] = _cursor.CursorResult(
             self, strategy, cursor_description
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 46c85d6f6c..49b4b97dd7 100644
--- a/lib/sqlalchemy/engine/result.py
+++ b/lib/sqlalchemy/engine/result.py
@@ -325,7 +325,7 @@ class SimpleResultMetaData(ResultMetaData):
         )
 
     def _index_for_key(self, key: Any, raiseerr: bool = True) -> int:
-        if int in key.__class__.__mro__:
+        if isinstance(key, int):
             key = self._keys[key]
         try:
             rec = self._keymap[key]
@@ -341,7 +341,7 @@ class SimpleResultMetaData(ResultMetaData):
         self, keys: Sequence[Any]
     ) -> Iterator[_KeyMapRecType]:
         for key in keys:
-            if int in key.__class__.__mro__:
+            if isinstance(key, int):
                 key = self._keys[key]
 
             try:
@@ -354,9 +354,7 @@ class SimpleResultMetaData(ResultMetaData):
     def _reduce(self, keys: Sequence[Any]) -> ResultMetaData:
         try:
             metadata_for_keys = [
-                self._keymap[
-                    self._keys[key] if int in key.__class__.__mro__ else key
-                ]
+                self._keymap[self._keys[key] if isinstance(key, int) else key]
                 for key in keys
             ]
         except KeyError as ke:
@@ -2187,7 +2185,8 @@ class FrozenResult(Generic[Unpack[_Ts]]):
         else:
             self.data = result.fetchall()
 
-    def rewrite_rows(self) -> Sequence[Sequence[Any]]:
+    def _rewrite_rows(self) -> Sequence[Sequence[Any]]:
+        # used only by the ORM function merge_frozen_result
         if self._source_supports_scalars:
             return [[elem] for elem in self.data]
         else:
diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py
index deee8bc3ad..ff28e2e204 100644
--- a/lib/sqlalchemy/orm/loading.py
+++ b/lib/sqlalchemy/orm/loading.py
@@ -340,7 +340,7 @@ def merge_frozen_result(session, statement, frozen_result, load=True):
         )
 
         result = []
-        for newrow in frozen_result.rewrite_rows():
+        for newrow in frozen_result._rewrite_rows():
             for i in mapped_entities:
                 if newrow[i] is not None:
                     newrow[i] = session._merge(
diff --git a/test/requirements.py b/test/requirements.py
index 72b609f21f..cfd2c74dad 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -1204,6 +1204,10 @@ class DefaultRequirements(SuiteRequirements):
     def sqlite_memory(self):
         return only_on(self._sqlite_memory_db)
 
+    @property
+    def sqlite_file(self):
+        return only_on(self._sqlite_file_db)
+
     def _sqlite_partial_idx(self, config):
         if not against(config, "sqlite"):
             return False
diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py
index f80b4c447e..a865bc1bb0 100644
--- a/test/sql/test_insert_exec.py
+++ b/test/sql/test_insert_exec.py
@@ -2959,3 +2959,211 @@ class IMVSentinelTest(fixtures.TestBase):
             coll = set
 
         eq_(coll(result), coll(expected_data))
+
+    @testing.variation("kind", ["returning", "returning_default"])
+    @testing.variation("operation", ["none", "yield_per", "unique", "columns"])
+    @testing.variation("has_processor", [True, False])
+    @testing.variation("freeze", [True, False])
+    @testing.variation("driver_column_names", [True, False])
+    def test_generative_cases(
+        self,
+        connection,
+        metadata,
+        sort_by_parameter_order,
+        kind,
+        operation,
+        has_processor,
+        freeze,
+        driver_column_names,
+    ):
+        class MyInt(TypeDecorator):
+            cache_ok = True
+            impl = Integer
+
+            def result_processor(self, dialect, coltype):
+                return str
+
+        class MyStr(TypeDecorator):
+            cache_ok = True
+            impl = String(42)
+
+            def result_processor(self, dialect, coltype):
+                return str.upper
+
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "id",
+                MyInt() if has_processor else Integer(),
+                primary_key=True,
+                test_needs_autoincrement=True,
+            ),
+            Column("data", MyStr() if has_processor else String(42)),
+            Column("w_d", String(42), server_default="foo"),
+        )
+
+        stmt = t1.insert()
+        data = [{"data": "a"}, {"data": "b"}, {"data": "c"}]
+        if kind.returning:
+            stmt = stmt.returning(
+                t1.c.data,
+                sort_by_parameter_order=bool(sort_by_parameter_order),
+            )
+            if has_processor:
+                expected = [("A",), ("B",), ("C",)]
+            else:
+                expected = [("a",), ("b",), ("c",)]
+        elif kind.returning_default:
+            stmt = stmt.return_defaults(
+                supplemental_cols=[t1.c.data],
+                sort_by_parameter_order=bool(sort_by_parameter_order),
+            )
+            if has_processor:
+                expected = [
+                    ("1", "A", "foo"),
+                    ("2", "B", "foo"),
+                    ("3", "C", "foo"),
+                ]
+            else:
+                expected = [(1, "a", "foo"), (2, "b", "foo"), (3, "c", "foo")]
+        else:
+            kind.fail()
+
+        if driver_column_names:
+            exec_options = {"driver_column_names": True}
+        else:
+            exec_options = {}
+
+        t1.create(connection)
+        r = connection.execute(stmt, data, execution_options=exec_options)
+
+        orig_expected = expected
+        if operation.none:
+            pass
+        elif operation.yield_per:
+            r = r.yield_per(2)
+        elif operation.unique:
+            r = r.unique()
+        elif operation.columns:
+            r = r.columns("data", "data")
+            if has_processor:
+                expected = [("A", "A"), ("B", "B"), ("C", "C")]
+            else:
+                expected = [("a", "a"), ("b", "b"), ("c", "c")]
+        else:
+            operation.fail()
+
+        if freeze:
+            rf = r.freeze()
+            res = rf().all()
+        else:
+            res = r.all()
+        eq_(res, expected)
+
+        rr = r._rewind(res)
+        if operation.unique:
+            # TODO: this seems like a bug. maybe just document it?
+            eq_(rr.all(), [])
+        else:
+            eq_(rr.all(), expected)
+
+        # re-execute to ensure it also works with the cache. The table is
+        # dropped and recreated to reset the autoincrement
+        t1.drop(connection)
+        t1.create(connection)
+        r2 = connection.execute(stmt, data, execution_options=exec_options)
+        eq_(r2.all(), orig_expected)
+
+    @testing.variation("sentinel", ["left", "right", "both"])
+    @testing.variation("has_processor", [True, False])
+    @testing.variation("freeze", [True, False])
+    @testing.skip_if(testing.requires.sqlite_file)
+    def test_splice_horizontally(
+        self, connection, metadata, sentinel, has_processor, freeze
+    ):
+        class MyInt(TypeDecorator):
+            cache_ok = True
+            impl = Integer
+
+            def result_processor(self, dialect, coltype):
+                return str
+
+        class MyStr(TypeDecorator):
+            cache_ok = True
+            impl = String(42)
+
+            def result_processor(self, dialect, coltype):
+                return str.upper
+
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "id",
+                MyInt() if has_processor else Integer(),
+                primary_key=True,
+                test_needs_autoincrement=True,
+            ),
+            Column("data", MyStr() if has_processor else String(42)),
+        )
+        t2 = Table(
+            "t2",
+            metadata,
+            Column(
+                "pk",
+                MyInt() if has_processor else Integer(),
+                primary_key=True,
+                test_needs_autoincrement=True,
+            ),
+            Column("dd", MyStr() if has_processor else String(42)),
+        )
+
+        left = t1.insert().returning(
+            t1.c.data, sort_by_parameter_order=sentinel.left or sentinel.both
+        )
+        data_left = [{"data": "a"}, {"data": "b"}, {"data": "c"}]
+        right = t2.insert().returning(
+            t2.c.dd, sort_by_parameter_order=sentinel.right or sentinel.both
+        )
+        data_right = [{"dd": "x"}, {"dd": "y"}, {"dd": "z"}]
+        if has_processor:
+            expected = [("A", "X"), ("B", "Y"), ("C", "Z")]
+        else:
+            expected = [("a", "x"), ("b", "y"), ("c", "z")]
+
+        with config.db.connect() as c2:
+            t1.create(connection)
+            t2.create(c2)
+            rl = connection.execute(left, data_left)
+            rr = c2.execute(right, data_right)
+
+            r = rl.splice_horizontally(rr)
+            if freeze:
+                rf = r.freeze()
+                res = rf().all()
+            else:
+                res = r.all()
+            eq_(res, expected)
+            rr = r._rewind(res)
+            eq_(rr.all(), expected)
+
+    def test_sentinel_not_in_result(self, connection, metadata):
+        t1 = Table(
+            "t1",
+            metadata,
+            Column(
+                "id",
+                Integer(),
+                primary_key=True,
+                test_needs_autoincrement=True,
+            ),
+            Column("data", String(42)),
+        )
+        stmt = t1.insert().returning(t1.c.data, sort_by_parameter_order=True)
+        t1.create(connection)
+        r = connection.execute(stmt, [{"data": "a"}, {"data": "b"}])
+
+        with expect_raises_message(IndexError, "list index out of range"):
+            r.scalars(1)
+        eq_(r.keys(), ["data"])
+        eq_(r.all(), [("a",), ("b",)])
diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py
index f87c6520d9..ea9ce57a1f 100644
--- a/test/sql/test_resultset.py
+++ b/test/sql/test_resultset.py
@@ -158,47 +158,95 @@ class CursorResultTest(fixtures.TablesTest):
         )
 
     @testing.requires.insert_executemany_returning
-    def test_splice_horizontally(self, connection):
+    @testing.variation("filters", ["unique", "sliced", "plain"])
+    def test_splice_horizontally(self, connection, filters):
         users = self.tables.users
         addresses = self.tables.addresses
 
-        r1 = connection.execute(
-            users.insert().returning(users.c.user_name, users.c.user_id),
-            [
-                dict(user_id=1, user_name="john"),
-                dict(user_id=2, user_name="jack"),
-            ],
-        )
+        if filters.unique:
+            r1 = connection.execute(
+                users.insert().returning(users.c.user_name),
+                [
+                    dict(user_id=1, user_name="john"),
+                    dict(user_id=2, user_name="john"),
+                ],
+            )
+            r2 = connection.execute(
+                addresses.insert().returning(
+                    addresses.c.address,
+                ),
+                [
+                    dict(address_id=1, user_id=1, address="foo@bar.com"),
+                    dict(address_id=2, user_id=2, address="foo@bar.com"),
+                ],
+            )
+        else:
+            r1 = connection.execute(
+                users.insert().returning(users.c.user_name, users.c.user_id),
+                [
+                    dict(user_id=1, user_name="john"),
+                    dict(user_id=2, user_name="jack"),
+                ],
+            )
+            r2 = connection.execute(
+                addresses.insert().returning(
+                    addresses.c.address_id,
+                    addresses.c.address,
+                    addresses.c.user_id,
+                ),
+                [
+                    dict(address_id=1, user_id=1, address="foo@bar.com"),
+                    dict(address_id=2, user_id=2, address="bar@bat.com"),
+                ],
+            )
 
-        r2 = connection.execute(
-            addresses.insert().returning(
-                addresses.c.address_id,
-                addresses.c.address,
-                addresses.c.user_id,
-            ),
-            [
-                dict(address_id=1, user_id=1, address="foo@bar.com"),
-                dict(address_id=2, user_id=2, address="bar@bat.com"),
-            ],
-        )
+        if filters.sliced:
+            r1 = r1.columns(users.c.user_name)
+            r2 = r2.columns(addresses.c.address, addresses.c.user_id)
+        elif filters.unique:
+            r1 = r1.unique()
+            r2 = r2.unique()
 
         rows = r1.splice_horizontally(r2).all()
-        eq_(
-            rows,
-            [
-                ("john", 1, 1, "foo@bar.com", 1),
-                ("jack", 2, 2, "bar@bat.com", 2),
-            ],
-        )
-        eq_(rows[0]._mapping[users.c.user_id], 1)
-        eq_(rows[0]._mapping[addresses.c.user_id], 1)
-        eq_(rows[1].address, "bar@bat.com")
+        if filters.sliced:
+            eq_(
+                rows,
+                [
+                    ("john", "foo@bar.com", 1),
+                    ("jack", "bar@bat.com", 2),
+                ],
+            )
+            eq_(rows[0]._mapping[users.c.user_name], "john")
+            eq_(rows[0].address, "foo@bar.com")
+        elif filters.unique:
+            eq_(
+                rows,
+                [
+                    ("john", "foo@bar.com"),
+                ],
+            )
+            eq_(rows[0]._mapping[users.c.user_name], "john")
+            eq_(rows[0].address, "foo@bar.com")
+        elif filters.plain:
+            eq_(
+                rows,
+                [
+                    ("john", 1, 1, "foo@bar.com", 1),
+                    ("jack", 2, 2, "bar@bat.com", 2),
+                ],
+            )
 
-        with expect_raises_message(
-            exc.InvalidRequestError, "Ambiguous column name 'user_id'"
-        ):
-            rows[0].user_id
+            eq_(rows[0]._mapping[users.c.user_id], 1)
+            eq_(rows[0]._mapping[addresses.c.user_id], 1)
+            eq_(rows[1].address, "bar@bat.com")
+
+            with expect_raises_message(
+                exc.InvalidRequestError, "Ambiguous column name 'user_id'"
+            ):
+                rows[0].user_id
+        else:
+            filters.fail()
 
     def test_keys_no_rows(self, connection):
        for i in range(2):