]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow update.returing() to work with from_statement()
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Feb 2021 04:21:04 +0000 (23:21 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 14 Feb 2021 19:51:45 +0000 (14:51 -0500)
The ORM used in :term:`2.0 style` can now return ORM objects from the rows
returned by an UPDATE..RETURNING or INSERT..RETURNING statement, by
supplying the construct to :meth:`_sql.Select.from_statement` in an ORM
context.

Change-Id: I59c9754ff1cb3184580dd5194ecd2971d4e7f8e8
References: #5940

doc/build/changelog/unreleased_14/orm_from_returning.rst [new file with mode: 0644]
doc/build/orm/persistence_techniques.rst
doc/build/orm/queryguide.rst
doc/build/orm/session_basics.rst
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/dml.py
test/orm/test_update_delete.py

diff --git a/doc/build/changelog/unreleased_14/orm_from_returning.rst b/doc/build/changelog/unreleased_14/orm_from_returning.rst
new file mode 100644 (file)
index 0000000..c3e720a
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: feature, orm
+
+    The ORM used in :term:`2.0 style` can now return ORM objects from the rows
+    returned by an UPDATE..RETURNING or INSERT..RETURNING statement, by
+    supplying the construct to :meth:`_sql.Select.from_statement` in an ORM
+    context.
+
+    .. seealso::
+
+      :ref:`orm_dml_returning_objects`
+
+
index c8daea2e6fb8ec7d2da1d259802629340f8cb0df..dad1f9f46081bad4b229f913e0ab0e548f6a10eb 100644 (file)
@@ -681,6 +681,11 @@ Bulk Operations
    bulk inserts, it's better to use the Core :class:`_sql.Insert` construct
    directly.   Please read all caveats at :ref:`bulk_operations_caveats`.
 
+.. note:: Bulk INSERT and UPDATE should not be confused with the
+   more common feature known as :ref:`orm_expression_update_delete`.   This
+   feature allows a single UPDATE or DELETE statement with arbitrary WHERE
+   criteria to be emitted.
+
 .. versionadded:: 1.0.0
 
 Bulk INSERT/per-row UPDATE operations on the :class:`.Session` include
index 5678d7cc73fefc08d504f17d6de5554c9689de96..7d23821383bad29a26bfc13dc27f02bca715d5ab 100644 (file)
@@ -365,6 +365,12 @@ is that in the former case, no subquery is produced in the resulting SQL.
 This can in some scenarios be advantageous from a performance or complexity
 perspective.
 
+.. seealso::
+
+  :ref:`orm_dml_returning_objects` - The :meth:`_sql.Select.from_statement`
+  method also works with :term:`DML` statements that support RETURNING.
+
+
 .. _orm_queryguide_joins:
 
 Joins
index fe6bb8a67da6bfff74708ddee67adf3558167a24..0ea797fa9511ef8c17a43730e2597166c49a76ca 100644 (file)
@@ -581,6 +581,10 @@ ORM-enabled delete, :term:`2.0 style`::
 
     session.execute(stmt)
 
+.. _orm_expression_update_delete_sync:
+
+Selecting a Synchronization Strategy
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 With both the 1.x and 2.0 form of ORM-enabled updates and deletes, the following
 values for ``synchronize_session`` are supported:
@@ -594,10 +598,12 @@ values for ``synchronize_session`` are supported:
   can lead to confusing results.
 
 * ``'fetch'`` - Retrieves the primary key identity of affected rows by either
-  performing a SELECT before the UPDATE or DELETE, or by using RETURNING
-  if the database supports it, so that in-memory objects which are affected
-  by the operation can be refreshed with new values (updates) or expunged
-  from the :class:`_orm.Session` (deletes)
+  performing a SELECT before the UPDATE or DELETE, or by using RETURNING if the
+  database supports it, so that in-memory objects which are affected by the
+  operation can be refreshed with new values (updates) or expunged from the
+  :class:`_orm.Session` (deletes). Note that this synchronization strategy is
+  not available if the given :func:`_dml.update` or :func:`_dml.delete`
+  construct specifies columns for :meth:`_dml.UpdateBase.returning` explicitly.
 
 * ``'evaluate'`` - Evaluate the WHERE criteria given in the UPDATE or DELETE
   statement in Python, to locate matching objects within the
@@ -669,7 +675,78 @@ values for ``synchronize_session`` are supported:
     * In order to intercept ORM-enabled UPDATE and DELETE operations with event
       handlers, use the :meth:`_orm.SessionEvents.do_orm_execute` event.
 
+.. _orm_dml_returning_objects:
+
+Selecting ORM Objects Inline with UPDATE.. RETURNING or INSERT..RETURNING
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. deepalchemy:: The feature of linking ORM objects to RETURNING is a new and
+   experimental feature.
+
+.. versionadded:: 1.4.0b3
+
+The :term:`DML` constructs :func:`_dml.insert`, :func:`_dml.update`, and
+:func:`_dml.delete` feature a method :meth:`_dml.UpdateBase.returning` which on
+database backends that support RETURNING (PostgreSQL, SQL Server, some MariaDB
+versions) may be used to return database rows generated or matched by
+the statement as though they were SELECTed. The ORM-enabled UPDATE and DELETE
+statements may be combined with this feature, so that they return rows
+corresponding to all the rows which were matched by the criteria::
+
+    from sqlalchemy import update
+
+    stmt = update(User).where(User.name == "squidward").values(name="spongebob").\
+        returning(User.id)
+
+    for row in session.execute(stmt):
+        print(f"id: {row.id}")
+
+The above example returns the ``User.id`` attribute for each row matched.
+Provided that each row contains at least a primary key value, we may opt to
+receive these rows as ORM objects, allowing ORM objects to be loaded from the
+database corresponding atomically to an UPDATE statement against those rows. To
+achieve this, we may combine the :class:`_dml.Update` construct which returns
+``User`` rows with a :func:`_sql.select` that's adapted to run this UPDATE
+statement in an ORM context using the :meth:`_sql.Select.from_statement`
+method::
+
+    stmt = update(User).where(User.name == "squidward").values(name="spongebob").\
+        returning(User)
+
+    orm_stmt = select(User).from_statement(stmt).execution_options(populate_existing=True)
+
+    for user in session.execute(orm_stmt).scalars():
+        print("updated user: %s" % user)
+
+Above, we produce an :func:`_dml.update` construct that includes
+:meth:`_dml.Update.returning` given the full ``User`` entity, which will
+produce complete rows from the database table as it UPDATEs them; any arbitrary
+set of columns to load may be specified as long as the full primary key is
+included. Next, these rows are adapted to an ORM load by producing a
+:func:`_sql.select` for the desired entity, then adapting it to the UPDATE
+statement by passing the :class:`_dml.Update` construct to the
+:meth:`_sql.Select.from_statement` method; this special ORM method, introduced
+at :ref:`orm_queryguide_selecting_text`, produces an ORM-specific adapter that
+allows the given statement to act as though it were the SELECT of rows that is
+first described.   No SELECT is actually emitted in the database, only the
+UPDATE..RETURNING we've constructed.
+
+Finally, we make use of :ref:`orm_queryguide_populate_existing` on the
+construct so that all the data returned by the UPDATE, including the columns
+we've updated, are populated into the returned objects, replacing any
+values which were there already.  This has the same effect as if we had
+used the ``synchronize_session='fetch'`` strategy described previously
+at :ref:`orm_expression_update_delete_sync`.
+
+The above approach can be used with INSERTs as well (and technically
+DELETEs too, though this makes less sense as the returned ORM objects
+by definition don't exist in the database anymore), as both of these
+constructs support RETURNING as well.
+
+.. seealso::
 
+  :ref:`orm_queryguide_selecting_text` - introduces the
+  :meth:`_sql.Select.from_statement` method.
 
 .. _session_committing:
 
index 674d5417949f39fcbe00797afb246d6f2beacc2b..a0aa67c69a6a59ce1524e600262ac20390b5ff87 100644 (file)
@@ -1526,7 +1526,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
         """Activate IDENTITY_INSERT if needed."""
 
         if self.isinsert:
-            tbl = self.compiled.statement.table
+            tbl = self.compiled.compile_state.dml_table
             id_column = tbl._autoincrement_column
             insert_has_identity = (id_column is not None) and (
                 not isinstance(id_column.default, Sequence)
@@ -1607,7 +1607,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                 self._opt_encode(
                     "SET IDENTITY_INSERT %s OFF"
                     % self.identifier_preparer.format_table(
-                        self.compiled.statement.table
+                        self.compiled.compile_state.dml_table
                     )
                 ),
                 (),
@@ -1631,7 +1631,7 @@ class MSExecutionContext(default.DefaultExecutionContext):
                     self._opt_encode(
                         "SET IDENTITY_INSERT %s OFF"
                         % self.identifier_preparer.format_table(
-                            self.compiled.statement.table
+                            self.compiled.compile_state.dml_table
                         )
                     )
                 )
index fa192a17e5f440f47999ae0ecbb9ebe362af15c9..23bae5cc0821fa74c076a8d2575a14cdc604fd31 100644 (file)
@@ -357,6 +357,9 @@ class ORMFromStatementCompileState(ORMCompileState):
         self.statement_container = self.select_statement = statement_container
         self.requested_statement = statement = statement_container.element
 
+        if statement.is_dml:
+            self.dml_table = statement.table
+
         self._entities = []
         self._polymorphic_adapters = {}
         self._no_yield_pers = set()
@@ -367,6 +370,7 @@ class ORMFromStatementCompileState(ORMCompileState):
             self.use_legacy_query_style
             and isinstance(statement, expression.SelectBase)
             and not statement._is_textual
+            and not statement.is_dml
             and statement._label_style is LABEL_STYLE_NONE
         ):
             self.statement = statement.set_label_style(
@@ -377,7 +381,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self._label_convention = self._column_naming_convention(
             statement._label_style
-            if not statement._is_textual
+            if not statement._is_textual and not statement.is_dml
             else LABEL_STYLE_NONE,
             self.use_legacy_query_style,
         )
@@ -409,7 +413,9 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         self.order_by = None
 
-        if isinstance(self.statement, expression.TextClause):
+        if isinstance(
+            self.statement, (expression.TextClause, expression.UpdateBase)
+        ):
             # setup for all entities. Currently, this is not useful
             # for eager loaders, as the eager loaders that work are able
             # to do their work entirely in row_processor.
@@ -790,12 +796,13 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         query = util.preloaded.orm_query
 
         from_statement = coercions.expect(
-            roles.SelectStatementRole,
+            roles.ReturnsRowsRole,
             from_statement,
             apply_propagate_attrs=statement,
         )
 
         stmt = query.FromStatement(statement._raw_columns, from_statement)
+
         stmt.__dict__.update(
             _with_options=statement._with_options,
             _with_context_options=statement._with_context_options,
index f19f29daa7e550d930dfb4f9a478be5f3e3e7273..7ab9eeda79d8eeeede8094f592f70ad1bc14f7ba 100644 (file)
@@ -2179,6 +2179,11 @@ class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
             compiler._annotations.get("synchronize_session", None) == "fetch"
             and compiler.dialect.full_returning
         ):
+            if new_stmt._returning:
+                raise sa_exc.InvalidRequestError(
+                    "Can't use synchronize_session='fetch' "
+                    "with explicit returning()"
+                )
             new_stmt = new_stmt.returning(*mapper.primary_key)
 
         UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
index 30cb9e73017a4f3f5da9fb5949c0df307447e0e9..c444b557bb814cda94ad115824a02b2fef88d546 100644 (file)
@@ -57,10 +57,12 @@ from ..sql.base import _generative
 from ..sql.base import Executable
 from ..sql.selectable import _SelectFromElements
 from ..sql.selectable import ForUpdateArg
+from ..sql.selectable import GroupedElement
 from ..sql.selectable import HasHints
 from ..sql.selectable import HasPrefixes
 from ..sql.selectable import HasSuffixes
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..sql.selectable import SelectBase
 from ..sql.selectable import SelectStatementGrouping
 from ..sql.visitors import InternalTraversal
 from ..util import collections_abc
@@ -3178,7 +3180,7 @@ class Query(
         return context
 
 
-class FromStatement(SelectStatementGrouping, Executable):
+class FromStatement(GroupedElement, SelectBase, Executable):
     """Core construct that represents a load of ORM objects from a finished
     select or text construct.
 
@@ -3210,7 +3212,19 @@ class FromStatement(SelectStatementGrouping, Executable):
             )
             for ent in util.to_list(entities)
         ]
-        super(FromStatement, self).__init__(element)
+        self.element = element
+
+    def get_label_style(self):
+        return self._label_style
+
+    def set_label_style(self, label_style):
+        return SelectStatementGrouping(
+            self.element.set_label_style(label_style)
+        )
+
+    @property
+    def _label_style(self):
+        return self.element._label_style
 
     def _compiler_dispatch(self, compiler, **kw):
 
@@ -3241,6 +3255,14 @@ class FromStatement(SelectStatementGrouping, Executable):
         for elem in super(FromStatement, self).get_children(**kw):
             yield elem
 
+    @property
+    def _returning(self):
+        return self.element._returning if self.element.is_dml else None
+
+    @property
+    def _inline(self):
+        return self.element._inline if self.element.is_dml else None
+
 
 class AliasOption(interfaces.LoaderOption):
     @util.deprecated(
index 3f492a490eb38ff74f59d0afc81dc98c8116b29e..ea10bfc27979e5830ab9933772299680e239139b 100644 (file)
@@ -47,6 +47,10 @@ class DMLState(CompileState):
     def __init__(self, statement, compiler, **kw):
         raise NotImplementedError()
 
+    @property
+    def dml_table(self):
+        return self.statement.table
+
     def _make_extra_froms(self, statement):
         froms = []
 
@@ -407,7 +411,9 @@ class UpdateBase(
             raise exc.InvalidRequestError(
                 "return_defaults() is already configured on this statement"
             )
-        self._returning += cols
+        self._returning += tuple(
+            coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+        )
 
     def _exported_columns_iterator(self):
         """Return the RETURNING columns as a sequence for this statement.
index e350ee018c989aa1e129fcbc69ed814c3bbb6b99..d437748f16cb8de04fe5713419eec383a7ab9bc2 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import lambda_stmt
 from sqlalchemy import or_
@@ -28,6 +29,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import not_in
+from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -903,6 +905,58 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 ),
             )
 
+    @testing.requires.full_returning
+    def test_update_explicit_returning(self):
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+        with self.sql_execution_asserter() as asserter:
+            stmt = (
+                update(User)
+                .filter(User.age > 29)
+                .values({"age": User.age - 10})
+                .returning(User.id)
+            )
+
+            rows = sess.execute(stmt).all()
+            eq_(rows, [(2,), (4,)])
+
+            # these are simple values, these are now evaluated even with
+            # the "fetch" strategy, new in 1.4, so there is no expiry
+            eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+
+        asserter.assert_(
+            CompiledSQL(
+                "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) "
+                "WHERE users.age_int > %(age_int_2)s RETURNING users.id",
+                [{"age_int_1": 10, "age_int_2": 29}],
+                dialect="postgresql",
+            ),
+        )
+
+    @testing.requires.full_returning
+    def test_no_fetch_w_explicit_returning(self):
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        stmt = (
+            update(User)
+            .filter(User.age > 29)
+            .values({"age": User.age - 10})
+            .execution_options(synchronize_session="fetch")
+            .returning(User.id)
+        )
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"Can't use synchronize_session='fetch' "
+            r"with explicit returning\(\)",
+        ):
+            sess.execute(stmt)
+
     def test_delete_fetch_returning(self):
         User = self.classes.User
 
@@ -2019,3 +2073,94 @@ class SingleTablePolymorphicTest(fixtures.DeclarativeMappedTest):
                 ("support", "n2", "d"),
             ],
         )
+
+
+class LoadFromReturningTest(fixtures.MappedTest):
+    __backend__ = True
+    __requires__ = ("full_returning",)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "users",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("name", String(32)),
+            Column("age_int", Integer),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class User(cls.Comparable):
+            pass
+
+        class Address(cls.Comparable):
+            pass
+
+    @classmethod
+    def insert_data(cls, connection):
+        users = cls.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                dict(id=1, name="john", age_int=25),
+                dict(id=2, name="jack", age_int=47),
+                dict(id=3, name="jill", age_int=29),
+                dict(id=4, name="jane", age_int=37),
+            ],
+        )
+
+    @classmethod
+    def setup_mappers(cls):
+        User = cls.classes.User
+        users = cls.tables.users
+
+        mapper(
+            User,
+            users,
+            properties={
+                "age": users.c.age_int,
+            },
+        )
+
+    def test_load_from_update(self, connection):
+        User = self.classes.User
+
+        stmt = (
+            update(User)
+            .where(User.name.in_(["jack", "jill"]))
+            .values(age=User.age + 5)
+            .returning(User)
+        )
+
+        stmt = select(User).from_statement(stmt)
+
+        with Session(connection) as sess:
+            rows = sess.execute(stmt).scalars().all()
+
+            eq_(
+                rows,
+                [User(name="jack", age=52), User(name="jill", age=34)],
+            )
+
+    def test_load_from_insert(self, connection):
+        User = self.classes.User
+
+        stmt = (
+            insert(User)
+            .values({User.id: 5, User.age: 25, User.name: "spongebob"})
+            .returning(User)
+        )
+
+        stmt = select(User).from_statement(stmt)
+
+        with Session(connection) as sess:
+            rows = sess.execute(stmt).scalars().all()
+
+            eq_(
+                rows,
+                [User(name="spongebob", age=25)],
+            )