From 82a1d4096fbfe94e2fa626d65d5c3beb2c6afa37 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 2 Aug 2022 16:18:18 -0400 Subject: [PATCH] include column.default, column.onupdate in eager_defaults Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults` parameter such that client-side SQL default or onupdate expressions in the table definition alone will trigger a fetch operation using RETURNING or SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only server side defaults established as part of table DDL and/or server-side onupdate expressions would trigger this fetch, even though client-side SQL expressions would be included when the fetch was rendered. Fixes: #7438 Change-Id: Iba719298ba4a26d185edec97ba77d2d54585e5a4 --- doc/build/changelog/unreleased_20/7438.rst | 11 ++ lib/sqlalchemy/orm/mapper.py | 72 +++++-- lib/sqlalchemy/orm/persistence.py | 26 ++- lib/sqlalchemy/sql/dml.py | 22 ++- test/orm/test_unitofworkv2.py | 213 ++++++++++++++++++++- test/sql/test_insert.py | 27 +++ 6 files changed, 331 insertions(+), 40 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/7438.rst diff --git a/doc/build/changelog/unreleased_20/7438.rst b/doc/build/changelog/unreleased_20/7438.rst new file mode 100644 index 0000000000..9aca39171d --- /dev/null +++ b/doc/build/changelog/unreleased_20/7438.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, orm + :tickets: 7438 + + Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults` + parameter such that client-side SQL default or onupdate expressions in the + table definition alone will trigger a fetch operation using RETURNING or + SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only + server side defaults established as part of table DDL and/or server-side + onupdate expressions would trigger this fetch, even though client-side SQL + expressions would be included when the fetch was rendered. diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 769b1b6236..6a95030b50 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -28,6 +28,7 @@ from typing import cast from typing import Collection from typing import Deque from typing import Dict +from typing import FrozenSet from typing import Generic from typing import Iterable from typing import Iterator @@ -2397,15 +2398,21 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _server_default_cols(self): + def _server_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: return dict( ( table, frozenset( [ - col.key - for col in columns + col + for col in cast("Iterable[Column[Any]]", columns) if col.server_default is not None + or ( + col.default is not None + and col.default.is_clause_element + ) ] ), ) @@ -2413,35 +2420,60 @@ class Mapper( ) @HasMemoized.memoized_attribute - def _server_default_plus_onupdate_propkeys(self): - result = set() - - for table, columns in self._cols_by_table.items(): - for col in columns: - if ( - col.server_default is not None - or col.server_onupdate is not None - ) and col in self._columntoproperty: - result.add(self._columntoproperty[col].key) - - return result - - @HasMemoized.memoized_attribute - def _server_onupdate_default_cols(self): + def _server_onupdate_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: return dict( ( table, frozenset( [ - col.key - for col in columns + col + for col in cast("Iterable[Column[Any]]", columns) if col.server_onupdate is not None + or ( + col.onupdate is not None + and col.onupdate.is_clause_element + ) ] ), ) for table, columns in self._cols_by_table.items() ) + @HasMemoized.memoized_attribute + def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_onupdate_default_col_keys( + self, + ) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_onupdate_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_default_plus_onupdate_propkeys(self) -> Set[str]: + result: Set[str] = set() + + col_to_property = self._columntoproperty + for table, columns in self._server_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + for table, columns in self._server_onupdate_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + return result + @HasMemoized.memoized_instancemethod def __clause_element__(self): diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index c10f4701e0..7cd66513ba 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -561,9 +561,9 @@ def _collect_insert_commands( has_all_pks = mapper._pk_keys_by_table[table].issubset(params) if mapper.base_mapper.eager_defaults: - has_all_defaults = mapper._server_default_cols[table].issubset( - params - ) + has_all_defaults = mapper._server_default_col_keys[ + table + ].issubset(params) else: has_all_defaults = True else: @@ -659,7 +659,7 @@ def _collect_update_commands( if mapper.base_mapper.eager_defaults: has_all_defaults = ( - mapper._server_onupdate_default_cols[table] + mapper._server_onupdate_default_col_keys[table] ).issubset(params) else: has_all_defaults = True @@ -930,16 +930,20 @@ def _emit_update_statements( return_defaults = False if not has_all_pks: - statement = statement.return_defaults() + statement = statement.return_defaults(*mapper._pks_by_table[table]) return_defaults = True - elif ( + + if ( bookkeeping and not has_all_defaults and mapper.base_mapper.eager_defaults ): - statement = statement.return_defaults() + statement = statement.return_defaults( + *mapper._server_onupdate_default_cols[table] + ) return_defaults = True - elif mapper.version_id_col is not None: + + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) return_defaults = True @@ -1171,8 +1175,10 @@ def _emit_insert_statements( do_executemany = False if not has_all_defaults and base_mapper.eager_defaults: - statement = statement.return_defaults() - elif mapper.version_id_col is not None: + statement = statement.return_defaults( + *mapper._server_default_cols[table] + ) + if mapper.version_id_col is not None: statement = statement.return_defaults(mapper.version_id_col) elif do_executemany: statement = statement.return_defaults(*table.primary_key) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 76a16eb1c8..9d489ed983 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -989,10 +989,26 @@ class ValuesBase(UpdateBase): :attr:`_engine.CursorResult.inserted_primary_key_rows` """ + + if self._return_defaults: + # note _return_defaults_columns = () means return all columns, + # so if we have been here before, only update collection if there + # are columns in the collection + if self._return_defaults_columns and cols: + self._return_defaults_columns = tuple( + set(self._return_defaults_columns).union( + coercions.expect(roles.ColumnsClauseRole, c) + for c in cols + ) + ) + else: + # set for all columns + self._return_defaults_columns = () + else: + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) self._return_defaults = True - self._return_defaults_columns = tuple( - coercions.expect(roles.ColumnsClauseRole, c) for c in cols - ) return self diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 68099a7a0e..dd3b889151 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -2353,6 +2353,21 @@ class EagerDefaultsTest(fixtures.MappedTest): Column("bar", Integer, server_onupdate=FetchedValue()), ) + Table( + "test3", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", String(50), default=func.lower("HI")), + ) + + Table( + "test4", + metadata, + Column("id", Integer, primary_key=True), + Column("foo", Integer), + Column("bar", Integer, onupdate=text("5 + 3")), + ) + @classmethod def setup_classes(cls): class Thing(cls.Basic): @@ -2361,6 +2376,12 @@ class EagerDefaultsTest(fixtures.MappedTest): class Thing2(cls.Basic): pass + class Thing3(cls.Basic): + pass + + class Thing4(cls.Basic): + pass + @classmethod def setup_mappers(cls): Thing = cls.classes.Thing @@ -2375,7 +2396,19 @@ class EagerDefaultsTest(fixtures.MappedTest): Thing2, cls.tables.test2, eager_defaults=True ) - def test_insert_defaults_present(self): + Thing3 = cls.classes.Thing3 + + cls.mapper_registry.map_imperatively( + Thing3, cls.tables.test3, eager_defaults=True + ) + + Thing4 = cls.classes.Thing4 + + cls.mapper_registry.map_imperatively( + Thing4, cls.tables.test4, eager_defaults=True + ) + + def test_server_insert_defaults_present(self): Thing = self.classes.Thing s = fixture_session() @@ -2388,7 +2421,10 @@ class EagerDefaultsTest(fixtures.MappedTest): s.flush, CompiledSQL( "INSERT INTO test (id, foo) VALUES (:id, :foo)", - [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}], + [ + {"foo": 5, "id": 1}, + {"foo": 10, "id": 2}, + ], ), ) @@ -2398,7 +2434,7 @@ class EagerDefaultsTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 0) - def test_insert_defaults_present_as_expr(self): + def test_server_insert_defaults_present_as_expr(self): Thing = self.classes.Thing s = fixture_session() @@ -2414,13 +2450,15 @@ class EagerDefaultsTest(fixtures.MappedTest): testing.db, s.flush, CompiledSQL( - "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) " + "INSERT INTO test (id, foo) " + "VALUES (%(id)s, 2 + 5) " "RETURNING test.foo", [{"id": 1}], dialect="postgresql", ), CompiledSQL( - "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) " + "INSERT INTO test (id, foo) " + "VALUES (%(id)s, 5 + 5) " "RETURNING test.foo", [{"id": 2}], dialect="postgresql", @@ -2457,7 +2495,7 @@ class EagerDefaultsTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 0) - def test_insert_defaults_nonpresent(self): + def test_server_insert_defaults_nonpresent(self): Thing = self.classes.Thing s = fixture_session() @@ -2516,7 +2554,73 @@ class EagerDefaultsTest(fixtures.MappedTest): ), ) - def test_update_defaults_nonpresent(self): + def test_clientsql_insert_defaults_nonpresent(self): + Thing3 = self.classes.Thing3 + s = fixture_session() + + t1, t2 = (Thing3(id=1), Thing3(id=2)) + + s.add_all([t1, t2]) + + self.assert_sql_execution( + testing.db, + s.commit, + Conditional( + testing.db.dialect.insert_returning, + [ + Conditional( + testing.db.dialect.insert_executemany_returning, + [ + CompiledSQL( + "INSERT INTO test3 (id, foo) " + "VALUES (%(id)s, lower(%(lower_1)s)) " + "RETURNING test3.foo", + [{"id": 1}, {"id": 2}], + dialect="postgresql", + ), + ], + [ + CompiledSQL( + "INSERT INTO test3 (id, foo) " + "VALUES (%(id)s, lower(%(lower_1)s)) " + "RETURNING test3.foo", + [{"id": 1}], + dialect="postgresql", + ), + CompiledSQL( + "INSERT INTO test3 (id, foo) " + "VALUES (%(id)s, lower(%(lower_1)s)) " + "RETURNING test3.foo", + [{"id": 2}], + dialect="postgresql", + ), + ], + ), + ], + [ + CompiledSQL( + "INSERT INTO test3 (id, foo) " + "VALUES (:id, lower(:lower_1))", + [ + {"id": 1, "lower_1": "HI"}, + {"id": 2, "lower_1": "HI"}, + ], + ), + CompiledSQL( + "SELECT test3.foo AS test3_foo " + "FROM test3 WHERE test3.id = :pk_1", + [{"pk_1": 1}], + ), + CompiledSQL( + "SELECT test3.foo AS test3_foo " + "FROM test3 WHERE test3.id = :pk_1", + [{"pk_1": 2}], + ), + ], + ), + ) + + def test_server_update_defaults_nonpresent(self): Thing2 = self.classes.Thing2 s = fixture_session() @@ -2611,6 +2715,101 @@ class EagerDefaultsTest(fixtures.MappedTest): self.assert_sql_count(testing.db, go, 0) + def test_clientsql_update_defaults_nonpresent(self): + Thing4 = self.classes.Thing4 + s = fixture_session() + + t1, t2, t3, t4 = ( + Thing4(id=1, foo=1), + Thing4(id=2, foo=2), + Thing4(id=3, foo=3), + Thing4(id=4, foo=4), + ) + + s.add_all([t1, t2, t3, t4]) + s.flush() + + t1.foo = 5 + t2.foo = 6 + t2.bar = 10 + t3.foo = 7 + t4.foo = 8 + t4.bar = 12 + + self.assert_sql_execution( + testing.db, + s.flush, + Conditional( + testing.db.dialect.update_returning, + [ + CompiledSQL( + "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 " + "WHERE test4.id = %(test4_id)s RETURNING test4.bar", + [{"foo": 5, "test4_id": 1}], + dialect="postgresql", + ), + CompiledSQL( + "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s " + "WHERE test4.id = %(test4_id)s", + [{"foo": 6, "bar": 10, "test4_id": 2}], + dialect="postgresql", + ), + CompiledSQL( + "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 WHERE " + "test4.id = %(test4_id)s RETURNING test4.bar", + [{"foo": 7, "test4_id": 3}], + dialect="postgresql", + ), + CompiledSQL( + "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s WHERE " + "test4.id = %(test4_id)s", + [{"foo": 8, "bar": 12, "test4_id": 4}], + dialect="postgresql", + ), + ], + [ + CompiledSQL( + "UPDATE test4 SET foo=:foo, bar=5 + 3 " + "WHERE test4.id = :test4_id", + [{"foo": 5, "test4_id": 1}], + ), + CompiledSQL( + "UPDATE test4 SET foo=:foo, bar=:bar " + "WHERE test4.id = :test4_id", + [{"foo": 6, "bar": 10, "test4_id": 2}], + ), + CompiledSQL( + "UPDATE test4 SET foo=:foo, bar=5 + 3 " + "WHERE test4.id = :test4_id", + [{"foo": 7, "test4_id": 3}], + ), + CompiledSQL( + "UPDATE test4 SET foo=:foo, bar=:bar " + "WHERE test4.id = :test4_id", + [{"foo": 8, "bar": 12, "test4_id": 4}], + ), + CompiledSQL( + "SELECT test4.bar AS test4_bar FROM test4 " + "WHERE test4.id = :pk_1", + [{"pk_1": 1}], + ), + CompiledSQL( + "SELECT test4.bar AS test4_bar FROM test4 " + "WHERE test4.id = :pk_1", + [{"pk_1": 3}], + ), + ], + ), + ) + + def go(): + eq_(t1.bar, 8) + eq_(t2.bar, 10) + eq_(t3.bar, 8) + eq_(t4.bar, 12) + + self.assert_sql_count(testing.db, go, 0) + def test_update_defaults_present_as_expr(self): Thing2 = self.classes.Thing2 s = fixture_session() diff --git a/test/sql/test_insert.py b/test/sql/test_insert.py index 071f595f39..61e0783e47 100644 --- a/test/sql/test_insert.py +++ b/test/sql/test_insert.py @@ -1,4 +1,7 @@ #! coding:utf-8 +from __future__ import annotations + +from typing import Tuple from sqlalchemy import bindparam from sqlalchemy import Column @@ -66,6 +69,30 @@ class _InsertTestBase: class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL): __dialect__ = "default" + @testing.combinations( + ((), ("z",), ()), + (("x",), (), ()), + (("x",), ("y",), ("x", "y")), + (("x", "y"), ("y",), ("x", "y")), + ) + def test_return_defaults_generative( + self, + initial_keys: Tuple[str, ...], + second_keys: Tuple[str, ...], + expected_keys: Tuple[str, ...], + ): + t = table("foo", column("x"), column("y"), column("z")) + + initial_cols = tuple(t.c[initial_keys]) + second_cols = tuple(t.c[second_keys]) + expected = set(t.c[expected_keys]) + + stmt = t.insert().return_defaults(*initial_cols) + eq_(stmt._return_defaults_columns, initial_cols) + stmt = stmt.return_defaults(*second_cols) + assert isinstance(stmt._return_defaults_columns, tuple) + eq_(set(stmt._return_defaults_columns), expected) + def test_binds_that_match_columns(self): """test bind params named after column names replace the normal SET/VALUES generation.""" -- 2.47.2