]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
include column.default, column.onupdate in eager_defaults
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Aug 2022 20:18:18 +0000 (16:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Aug 2022 14:07:15 +0000 (10:07 -0400)
Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults`
parameter such that client-side SQL default or onupdate expressions in the
table definition alone will trigger a fetch operation using RETURNING or
SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only
server side defaults established as part of table DDL and/or server-side
onupdate expressions would trigger this fetch, even though client-side SQL
expressions would be included when the fetch was rendered.

Fixes: #7438
Change-Id: Iba719298ba4a26d185edec97ba77d2d54585e5a4

doc/build/changelog/unreleased_20/7438.rst [new file with mode: 0644]
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/dml.py
test/orm/test_unitofworkv2.py
test/sql/test_insert.py

diff --git a/doc/build/changelog/unreleased_20/7438.rst b/doc/build/changelog/unreleased_20/7438.rst
new file mode 100644 (file)
index 0000000..9aca391
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 7438
+
+    Fixed bug in the behavior of the :paramref:`_orm.Mapper.eager_defaults`
+    parameter such that client-side SQL default or onupdate expressions in the
+    table definition alone will trigger a fetch operation using RETURNING or
+    SELECT when the ORM emits an INSERT or UPDATE for the row. Previously, only
+    server side defaults established as part of table DDL and/or server-side
+    onupdate expressions would trigger this fetch, even though client-side SQL
+    expressions would be included when the fetch was rendered.
index 769b1b6236c451a4e5d3547ac7cd5a66a8d25081..6a95030b504836afc3884464abeaf5b9aec497f0 100644 (file)
@@ -28,6 +28,7 @@ from typing import cast
 from typing import Collection
 from typing import Deque
 from typing import Dict
+from typing import FrozenSet
 from typing import Generic
 from typing import Iterable
 from typing import Iterator
@@ -2397,15 +2398,21 @@ class Mapper(
         )
 
     @HasMemoized.memoized_attribute
-    def _server_default_cols(self):
+    def _server_default_cols(
+        self,
+    ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
         return dict(
             (
                 table,
                 frozenset(
                     [
-                        col.key
-                        for col in columns
+                        col
+                        for col in cast("Iterable[Column[Any]]", columns)
                         if col.server_default is not None
+                        or (
+                            col.default is not None
+                            and col.default.is_clause_element
+                        )
                     ]
                 ),
             )
@@ -2413,35 +2420,60 @@ class Mapper(
         )
 
     @HasMemoized.memoized_attribute
-    def _server_default_plus_onupdate_propkeys(self):
-        result = set()
-
-        for table, columns in self._cols_by_table.items():
-            for col in columns:
-                if (
-                    col.server_default is not None
-                    or col.server_onupdate is not None
-                ) and col in self._columntoproperty:
-                    result.add(self._columntoproperty[col].key)
-
-        return result
-
-    @HasMemoized.memoized_attribute
-    def _server_onupdate_default_cols(self):
+    def _server_onupdate_default_cols(
+        self,
+    ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
         return dict(
             (
                 table,
                 frozenset(
                     [
-                        col.key
-                        for col in columns
+                        col
+                        for col in cast("Iterable[Column[Any]]", columns)
                         if col.server_onupdate is not None
+                        or (
+                            col.onupdate is not None
+                            and col.onupdate.is_clause_element
+                        )
                     ]
                 ),
             )
             for table, columns in self._cols_by_table.items()
         )
 
+    @HasMemoized.memoized_attribute
+    def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]:
+        return {
+            table: frozenset(col.key for col in cols if col.key is not None)
+            for table, cols in self._server_default_cols.items()
+        }
+
+    @HasMemoized.memoized_attribute
+    def _server_onupdate_default_col_keys(
+        self,
+    ) -> Mapping[FromClause, FrozenSet[str]]:
+        return {
+            table: frozenset(col.key for col in cols if col.key is not None)
+            for table, cols in self._server_onupdate_default_cols.items()
+        }
+
+    @HasMemoized.memoized_attribute
+    def _server_default_plus_onupdate_propkeys(self) -> Set[str]:
+        result: Set[str] = set()
+
+        col_to_property = self._columntoproperty
+        for table, columns in self._server_default_cols.items():
+            result.update(
+                col_to_property[col].key
+                for col in columns.intersection(col_to_property)
+            )
+        for table, columns in self._server_onupdate_default_cols.items():
+            result.update(
+                col_to_property[col].key
+                for col in columns.intersection(col_to_property)
+            )
+        return result
+
     @HasMemoized.memoized_instancemethod
     def __clause_element__(self):
 
index c10f4701e031e2009a2ca2fb972eb99d111b8a5e..7cd66513ba6a8cda2094d871965d85aa4226d8ff 100644 (file)
@@ -561,9 +561,9 @@ def _collect_insert_commands(
             has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
 
             if mapper.base_mapper.eager_defaults:
-                has_all_defaults = mapper._server_default_cols[table].issubset(
-                    params
-                )
+                has_all_defaults = mapper._server_default_col_keys[
+                    table
+                ].issubset(params)
             else:
                 has_all_defaults = True
         else:
@@ -659,7 +659,7 @@ def _collect_update_commands(
 
             if mapper.base_mapper.eager_defaults:
                 has_all_defaults = (
-                    mapper._server_onupdate_default_cols[table]
+                    mapper._server_onupdate_default_col_keys[table]
                 ).issubset(params)
             else:
                 has_all_defaults = True
@@ -930,16 +930,20 @@ def _emit_update_statements(
         return_defaults = False
 
         if not has_all_pks:
-            statement = statement.return_defaults()
+            statement = statement.return_defaults(*mapper._pks_by_table[table])
             return_defaults = True
-        elif (
+
+        if (
             bookkeeping
             and not has_all_defaults
             and mapper.base_mapper.eager_defaults
         ):
-            statement = statement.return_defaults()
+            statement = statement.return_defaults(
+                *mapper._server_onupdate_default_cols[table]
+            )
             return_defaults = True
-        elif mapper.version_id_col is not None:
+
+        if mapper.version_id_col is not None:
             statement = statement.return_defaults(mapper.version_id_col)
             return_defaults = True
 
@@ -1171,8 +1175,10 @@ def _emit_insert_statements(
                 do_executemany = False
 
             if not has_all_defaults and base_mapper.eager_defaults:
-                statement = statement.return_defaults()
-            elif mapper.version_id_col is not None:
+                statement = statement.return_defaults(
+                    *mapper._server_default_cols[table]
+                )
+            if mapper.version_id_col is not None:
                 statement = statement.return_defaults(mapper.version_id_col)
             elif do_executemany:
                 statement = statement.return_defaults(*table.primary_key)
index 76a16eb1c84f90a3148631ea72e9fc054fc9f743..9d489ed983917cc9b2dd5db1e918557a71acdcac 100644 (file)
@@ -989,10 +989,26 @@ class ValuesBase(UpdateBase):
             :attr:`_engine.CursorResult.inserted_primary_key_rows`
 
         """
+
+        if self._return_defaults:
+            # note _return_defaults_columns = () means return all columns,
+            # so if we have been here before, only update collection if there
+            # are columns in the collection
+            if self._return_defaults_columns and cols:
+                self._return_defaults_columns = tuple(
+                    set(self._return_defaults_columns).union(
+                        coercions.expect(roles.ColumnsClauseRole, c)
+                        for c in cols
+                    )
+                )
+            else:
+                # set for all columns
+                self._return_defaults_columns = ()
+        else:
+            self._return_defaults_columns = tuple(
+                coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+            )
         self._return_defaults = True
-        self._return_defaults_columns = tuple(
-            coercions.expect(roles.ColumnsClauseRole, c) for c in cols
-        )
         return self
 
 
index 68099a7a0ea7f0d5852e2db71135bc01c4b7373a..dd3b88915113e10a7c416060830ed14972ce49cd 100644 (file)
@@ -2353,6 +2353,21 @@ class EagerDefaultsTest(fixtures.MappedTest):
             Column("bar", Integer, server_onupdate=FetchedValue()),
         )
 
+        Table(
+            "test3",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("foo", String(50), default=func.lower("HI")),
+        )
+
+        Table(
+            "test4",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("foo", Integer),
+            Column("bar", Integer, onupdate=text("5 + 3")),
+        )
+
     @classmethod
     def setup_classes(cls):
         class Thing(cls.Basic):
@@ -2361,6 +2376,12 @@ class EagerDefaultsTest(fixtures.MappedTest):
         class Thing2(cls.Basic):
             pass
 
+        class Thing3(cls.Basic):
+            pass
+
+        class Thing4(cls.Basic):
+            pass
+
     @classmethod
     def setup_mappers(cls):
         Thing = cls.classes.Thing
@@ -2375,7 +2396,19 @@ class EagerDefaultsTest(fixtures.MappedTest):
             Thing2, cls.tables.test2, eager_defaults=True
         )
 
-    def test_insert_defaults_present(self):
+        Thing3 = cls.classes.Thing3
+
+        cls.mapper_registry.map_imperatively(
+            Thing3, cls.tables.test3, eager_defaults=True
+        )
+
+        Thing4 = cls.classes.Thing4
+
+        cls.mapper_registry.map_imperatively(
+            Thing4, cls.tables.test4, eager_defaults=True
+        )
+
+    def test_server_insert_defaults_present(self):
         Thing = self.classes.Thing
         s = fixture_session()
 
@@ -2388,7 +2421,10 @@ class EagerDefaultsTest(fixtures.MappedTest):
             s.flush,
             CompiledSQL(
                 "INSERT INTO test (id, foo) VALUES (:id, :foo)",
-                [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}],
+                [
+                    {"foo": 5, "id": 1},
+                    {"foo": 10, "id": 2},
+                ],
             ),
         )
 
@@ -2398,7 +2434,7 @@ class EagerDefaultsTest(fixtures.MappedTest):
 
         self.assert_sql_count(testing.db, go, 0)
 
-    def test_insert_defaults_present_as_expr(self):
+    def test_server_insert_defaults_present_as_expr(self):
         Thing = self.classes.Thing
         s = fixture_session()
 
@@ -2414,13 +2450,15 @@ class EagerDefaultsTest(fixtures.MappedTest):
                 testing.db,
                 s.flush,
                 CompiledSQL(
-                    "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
+                    "INSERT INTO test (id, foo) "
+                    "VALUES (%(id)s, 2 + 5) "
                     "RETURNING test.foo",
                     [{"id": 1}],
                     dialect="postgresql",
                 ),
                 CompiledSQL(
-                    "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
+                    "INSERT INTO test (id, foo) "
+                    "VALUES (%(id)s, 5 + 5) "
                     "RETURNING test.foo",
                     [{"id": 2}],
                     dialect="postgresql",
@@ -2457,7 +2495,7 @@ class EagerDefaultsTest(fixtures.MappedTest):
 
         self.assert_sql_count(testing.db, go, 0)
 
-    def test_insert_defaults_nonpresent(self):
+    def test_server_insert_defaults_nonpresent(self):
         Thing = self.classes.Thing
         s = fixture_session()
 
@@ -2516,7 +2554,73 @@ class EagerDefaultsTest(fixtures.MappedTest):
             ),
         )
 
-    def test_update_defaults_nonpresent(self):
+    def test_clientsql_insert_defaults_nonpresent(self):
+        Thing3 = self.classes.Thing3
+        s = fixture_session()
+
+        t1, t2 = (Thing3(id=1), Thing3(id=2))
+
+        s.add_all([t1, t2])
+
+        self.assert_sql_execution(
+            testing.db,
+            s.commit,
+            Conditional(
+                testing.db.dialect.insert_returning,
+                [
+                    Conditional(
+                        testing.db.dialect.insert_executemany_returning,
+                        [
+                            CompiledSQL(
+                                "INSERT INTO test3 (id, foo) "
+                                "VALUES (%(id)s, lower(%(lower_1)s)) "
+                                "RETURNING test3.foo",
+                                [{"id": 1}, {"id": 2}],
+                                dialect="postgresql",
+                            ),
+                        ],
+                        [
+                            CompiledSQL(
+                                "INSERT INTO test3 (id, foo) "
+                                "VALUES (%(id)s, lower(%(lower_1)s)) "
+                                "RETURNING test3.foo",
+                                [{"id": 1}],
+                                dialect="postgresql",
+                            ),
+                            CompiledSQL(
+                                "INSERT INTO test3 (id, foo) "
+                                "VALUES (%(id)s, lower(%(lower_1)s)) "
+                                "RETURNING test3.foo",
+                                [{"id": 2}],
+                                dialect="postgresql",
+                            ),
+                        ],
+                    ),
+                ],
+                [
+                    CompiledSQL(
+                        "INSERT INTO test3 (id, foo) "
+                        "VALUES (:id, lower(:lower_1))",
+                        [
+                            {"id": 1, "lower_1": "HI"},
+                            {"id": 2, "lower_1": "HI"},
+                        ],
+                    ),
+                    CompiledSQL(
+                        "SELECT test3.foo AS test3_foo "
+                        "FROM test3 WHERE test3.id = :pk_1",
+                        [{"pk_1": 1}],
+                    ),
+                    CompiledSQL(
+                        "SELECT test3.foo AS test3_foo "
+                        "FROM test3 WHERE test3.id = :pk_1",
+                        [{"pk_1": 2}],
+                    ),
+                ],
+            ),
+        )
+
+    def test_server_update_defaults_nonpresent(self):
         Thing2 = self.classes.Thing2
         s = fixture_session()
 
@@ -2611,6 +2715,101 @@ class EagerDefaultsTest(fixtures.MappedTest):
 
         self.assert_sql_count(testing.db, go, 0)
 
+    def test_clientsql_update_defaults_nonpresent(self):
+        Thing4 = self.classes.Thing4
+        s = fixture_session()
+
+        t1, t2, t3, t4 = (
+            Thing4(id=1, foo=1),
+            Thing4(id=2, foo=2),
+            Thing4(id=3, foo=3),
+            Thing4(id=4, foo=4),
+        )
+
+        s.add_all([t1, t2, t3, t4])
+        s.flush()
+
+        t1.foo = 5
+        t2.foo = 6
+        t2.bar = 10
+        t3.foo = 7
+        t4.foo = 8
+        t4.bar = 12
+
+        self.assert_sql_execution(
+            testing.db,
+            s.flush,
+            Conditional(
+                testing.db.dialect.update_returning,
+                [
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 "
+                        "WHERE test4.id = %(test4_id)s RETURNING test4.bar",
+                        [{"foo": 5, "test4_id": 1}],
+                        dialect="postgresql",
+                    ),
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s "
+                        "WHERE test4.id = %(test4_id)s",
+                        [{"foo": 6, "bar": 10, "test4_id": 2}],
+                        dialect="postgresql",
+                    ),
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 WHERE "
+                        "test4.id = %(test4_id)s RETURNING test4.bar",
+                        [{"foo": 7, "test4_id": 3}],
+                        dialect="postgresql",
+                    ),
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s WHERE "
+                        "test4.id = %(test4_id)s",
+                        [{"foo": 8, "bar": 12, "test4_id": 4}],
+                        dialect="postgresql",
+                    ),
+                ],
+                [
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=:foo, bar=5 + 3 "
+                        "WHERE test4.id = :test4_id",
+                        [{"foo": 5, "test4_id": 1}],
+                    ),
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=:foo, bar=:bar "
+                        "WHERE test4.id = :test4_id",
+                        [{"foo": 6, "bar": 10, "test4_id": 2}],
+                    ),
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=:foo, bar=5 + 3 "
+                        "WHERE test4.id = :test4_id",
+                        [{"foo": 7, "test4_id": 3}],
+                    ),
+                    CompiledSQL(
+                        "UPDATE test4 SET foo=:foo, bar=:bar "
+                        "WHERE test4.id = :test4_id",
+                        [{"foo": 8, "bar": 12, "test4_id": 4}],
+                    ),
+                    CompiledSQL(
+                        "SELECT test4.bar AS test4_bar FROM test4 "
+                        "WHERE test4.id = :pk_1",
+                        [{"pk_1": 1}],
+                    ),
+                    CompiledSQL(
+                        "SELECT test4.bar AS test4_bar FROM test4 "
+                        "WHERE test4.id = :pk_1",
+                        [{"pk_1": 3}],
+                    ),
+                ],
+            ),
+        )
+
+        def go():
+            eq_(t1.bar, 8)
+            eq_(t2.bar, 10)
+            eq_(t3.bar, 8)
+            eq_(t4.bar, 12)
+
+        self.assert_sql_count(testing.db, go, 0)
+
     def test_update_defaults_present_as_expr(self):
         Thing2 = self.classes.Thing2
         s = fixture_session()
index 071f595f394d5b100a5cc404688101ddabfe68c7..61e0783e4733d6018392cc264f17947bf05b000c 100644 (file)
@@ -1,4 +1,7 @@
 #! coding:utf-8
+from __future__ import annotations
+
+from typing import Tuple
 
 from sqlalchemy import bindparam
 from sqlalchemy import Column
@@ -66,6 +69,30 @@ class _InsertTestBase:
 class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = "default"
 
+    @testing.combinations(
+        ((), ("z",), ()),
+        (("x",), (), ()),
+        (("x",), ("y",), ("x", "y")),
+        (("x", "y"), ("y",), ("x", "y")),
+    )
+    def test_return_defaults_generative(
+        self,
+        initial_keys: Tuple[str, ...],
+        second_keys: Tuple[str, ...],
+        expected_keys: Tuple[str, ...],
+    ):
+        t = table("foo", column("x"), column("y"), column("z"))
+
+        initial_cols = tuple(t.c[initial_keys])
+        second_cols = tuple(t.c[second_keys])
+        expected = set(t.c[expected_keys])
+
+        stmt = t.insert().return_defaults(*initial_cols)
+        eq_(stmt._return_defaults_columns, initial_cols)
+        stmt = stmt.return_defaults(*second_cols)
+        assert isinstance(stmt._return_defaults_columns, tuple)
+        eq_(set(stmt._return_defaults_columns), expected)
+
     def test_binds_that_match_columns(self):
         """test bind params named after column names
         replace the normal SET/VALUES generation."""