ORM bulk insert via execute
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Sun, 7 Aug 2022 16:14:19 +0000 (12:14 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Sat, 24 Sep 2022 15:18:01 +0000 (11:18 -0400)
* ORM Insert now includes "bulk" mode that will run
  essentially the same process as session.bulk_insert_mappings;
  the given list of values is interpreted as dictionaries whose
  keys are ORM attribute names (a sketch follows this list)
* ORM UPDATE has a similar "bulk" feature, without RETURNING support,
  mirroring session.bulk_update_mappings
* Added support for upserts to RETURN ORM objects as well
* ORM UPDATE/DELETE with a list of parameters plus WHERE criteria
  is not implemented; use a connection instead
* ORM UPDATE/DELETE defaults to "auto" synchronize_session;
  use "fetch" if RETURNING is present, "evaluate" if not, as
  "fetch" is much more efficient (no expired-object SELECT problem)
  and less error prone when RETURNING is available
  UPDATE: however this is inefficient!  please continue to
  use "evaluate" for simple cases; "auto" can move to "fetch"
  if the criteria are not evaluable (see the second sketch after
  the closing notes below)
* "Evaluate" criteria will now not preemptively
  unexpire and SELECT attributes that were individually
  expired. Instead, if evaluation of the criteria indicates that
  the necessary attrs were expired, we expire the object
  completely (delete) or expire the SET attrs unconditionally
  (update). This keeps the object in the same unloaded state
  where it will refresh those attrs on the next pass, for
  this generally unusual case.  (originally #5664)
* Core change! UPDATE/DELETE rowcount now comes from len(rows)
  if RETURNING was used.  SQLite, at least, otherwise did not
  support this; adjusted test_rowcount accordingly
* ORM DELETE with a list of parameters is also not
  implemented, as this would imply "bulk", and there is no
  bulk_delete_mappings (there could be, but we don't have that)
* ORM insert().values() with single or multi-values translates
  key names based on ORM attribute names
* ORM returning() implemented for insert, update, delete;
  explicit returning clauses now interpret rows in an ORM
  context, with support for qualifying loader options as well
* session.bulk_insert_mappings() assigns polymorphic identity
  if not set.
* explicit RETURNING + synchronize_session='fetch' is now
  supported with UPDATE and DELETE.
* expanded return_defaults() to work with DELETE also.
* added support for composite attributes in the dictionaries
  used by bulk_insert_mappings and bulk_update_mappings, and
  therefore the new ORM bulk insert/update feature as well;
  composite values are expanded into their individual mapped
  attributes the way they'd be on a mapped instance.
* bulk UPDATE supports synchronize_session="evaluate", which is the
  default.  this does not apply to session.bulk_update_mappings,
  just the new version
* both bulk UPDATE and bulk INSERT, the latter with or without
  RETURNING, support *heterogeneous* parameter sets.
  session.bulk_insert/update_mappings did this, so this feature
  is maintained.  now the cursor result can be both horizontally
  and vertically spliced :)
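
Below is a minimal sketch of the bulk INSERT / UPDATE paths described
above; the User model, its column names, and the in-memory SQLite
database are illustrative assumptions only and are not part of this
change set (a backend/DBAPI with RETURNING support is assumed for the
INSERT..RETURNING part):

    from sqlalchemy import create_engine, insert, update
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        Session,
        mapped_column,
    )

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]
        fullname: Mapped[str]

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # "bulk" mode: a list of dictionaries keyed on ORM attribute
        # names; RETURNING hands back ORM User objects built from the
        # spliced-together result rows
        inserted = session.scalars(
            insert(User).returning(User),
            [
                {"name": "spongebob", "fullname": "Spongebob Squarepants"},
                {"name": "sandy", "fullname": "Sandy Cheeks"},
            ],
        ).all()

        # bulk UPDATE by primary key; RETURNING is not supported here
        session.execute(
            update(User),
            [
                {"id": inserted[0].id, "fullname": "Spongebob S."},
                {"id": inserted[1].id, "fullname": "Sandy C."},
            ],
        )
        session.commit()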

This is now a long story with a lot of options, which is in
itself a problem when it comes to documenting all of this
in a way that makes sense.  Raising exceptions for
use cases we haven't supported is pretty important here
too; the tradition of letting unsupported things just silently
not work is likely not a good idea at this point, though there
are still many cases that aren't easily avoidable
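
Continuing the sketch above (same engine and User model, RETURNING
support in the backend assumed), a hedged illustration of an
ORM-enabled UPDATE with WHERE criteria, explicit returning(), and the
new "auto" synchronize_session default:

    from sqlalchemy import update
    from sqlalchemy.orm import Session

    with Session(engine) as session:
        # rows are interpreted in an ORM context as User objects;
        # "auto" synchronizes via "evaluate" when the criteria can be
        # evaluated in Python, falling back to "fetch" otherwise
        updated = session.scalars(
            update(User)
            .where(User.name == "sandy")
            .values(fullname="Sandy Cheeks, PhD")
            .returning(User),
            execution_options={"synchronize_session": "auto"},
        ).all()
        session.commit()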

Fixes: #8360
Fixes: #7864
Fixes: #7865
Change-Id: Idf28379f8705e403a3c6a937f6a798a042ef2540

45 files changed:
doc/build/orm/session_basics.rst
lib/sqlalchemy/dialects/mssql/pyodbc.py
lib/sqlalchemy/dialects/postgresql/provision.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/evaluator.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/testing/assertsql.py
lib/sqlalchemy/testing/fixtures.py
lib/sqlalchemy/testing/suite/test_rowcount.py
lib/sqlalchemy/util/_py_collections.py
test/ext/test_horizontal_shard.py
test/ext/test_hybrid.py
test/orm/dml/__init__.py [new file with mode: 0644]
test/orm/dml/test_bulk.py [moved from test/orm/test_bulk.py with 80% similarity]
test/orm/dml/test_bulk_statements.py [new file with mode: 0644]
test/orm/dml/test_evaluator.py [moved from test/orm/test_evaluator.py with 99% similarity]
test/orm/dml/test_update_delete_where.py [moved from test/orm/test_update_delete.py with 83% similarity]
test/orm/inheritance/test_basic.py
test/orm/test_bind.py
test/orm/test_composites.py
test/orm/test_cycles.py
test/orm/test_defaults.py
test/orm/test_events.py
test/orm/test_unitofwork.py
test/orm/test_unitofworkv2.py
test/orm/test_versioning.py
test/sql/test_resultset.py
test/sql/test_returning.py
test/sql/test_selectable.py

doc/build/orm/session_basics.rst
index 96b9d8b5cace77281fcae2a24c0615741919800f..6b7ef329917a9061d0aa8ea9f4dce8029da477ee 100644 (file)
@@ -660,6 +660,17 @@ Selecting a Synchronization Strategy
 With both the 1.x and 2.0 form of ORM-enabled updates and deletes, the following
 values for ``synchronize_session`` are supported:
 
+* ``'auto'`` - this is the default.   The ``'fetch'`` strategy will be used on
+  backends that support RETURNING, which includes all SQLAlchemy-native drivers
+  except for MySQL.   If RETURNING is not supported, the ``'evaluate'``
+  strategy will be used instead.
+
+  .. versionchanged:: 2.0  Added the ``'auto'`` synchronization strategy.  As
+     most backends now support RETURNING, selecting ``'fetch'`` for these
+     backends specifically is the more efficient and error-free default.
+     The MySQL backend as well as third party backends without RETURNING
+     support will continue to use ``'evaluate'`` by default.
+
 * ``False`` - don't synchronize the session. This option is the most
   efficient and is reliable once the session is expired, which
   typically occurs after a commit(), or explicitly using
lib/sqlalchemy/dialects/mssql/pyodbc.py
index 2eef971cc50f863648b0d809151af41ba8963c89..5eb6b952824f3d5d8f79e3c4a565c00e6bdfcf05 100644 (file)
@@ -301,7 +301,7 @@ Fast Executemany Mode
    The SQL Server ``fast_executemany`` parameter may be used at the same time
    as ``insertmanyvalues`` is enabled; however, the parameter will not be used
    in as many cases as INSERT statements that are invoked using Core
-   :class:`.Insert` constructs as well as all ORM use no longer use the
+   :class:`_dml.Insert` constructs as well as all ORM use no longer use the
    ``.executemany()`` DBAPI cursor method.
 
 The PyODBC driver includes support for a "fast executemany" mode of execution
lib/sqlalchemy/dialects/postgresql/provision.py
index 8dd8a4995152ae7843588f69da39a7d7ef86d612..4609701a2cf776385c93689767709a9eefbafb25 100644 (file)
@@ -134,9 +134,11 @@ def _upsert(cfg, table, returning, set_lambda=None):
 
     stmt = insert(table)
 
+    table_pk = inspect(table).selectable
+
     if set_lambda:
         stmt = stmt.on_conflict_do_update(
-            index_elements=table.primary_key, set_=set_lambda(stmt.excluded)
+            index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
         )
     else:
         stmt = stmt.on_conflict_do_nothing()
lib/sqlalchemy/dialects/sqlite/base.py
index e57a84fe0d4975db87614ea75ac9e761bad4b4ee..5f468edbe30c049ec070e042010bfde9b1dbdbac 100644 (file)
@@ -1466,11 +1466,6 @@ class SQLiteCompiler(compiler.SQLCompiler):
 
         return target_text
 
-    def visit_insert(self, insert_stmt, **kw):
-        if insert_stmt._post_values_clause is not None:
-            kw["disable_implicit_returning"] = True
-        return super().visit_insert(insert_stmt, **kw)
-
     def visit_on_conflict_do_nothing(self, on_conflict, **kw):
 
         target_text = self._on_conflict_target(on_conflict, **kw)
lib/sqlalchemy/engine/cursor.py
index 8840b5916138d513ad7543b381baf8f5bcf199bd..07e78229678d9504af7a94a3369b730f8f1517aa 100644 (file)
@@ -23,12 +23,14 @@ from typing import Iterator
 from typing import List
 from typing import NoReturn
 from typing import Optional
+from typing import overload
 from typing import Sequence
 from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
+from .result import IteratorResult
 from .result import MergedResult
 from .result import Result
 from .result import ResultMetaData
@@ -62,36 +64,80 @@ if typing.TYPE_CHECKING:
     from .interfaces import ExecutionContext
     from .result import _KeyIndexType
     from .result import _KeyMapRecType
+    from .result import _KeyMapType
     from .result import _KeyType
     from .result import _ProcessorsType
+    from .result import _TupleGetterType
     from ..sql.type_api import _ResultProcessorType
 
 
 _T = TypeVar("_T", bound=Any)
 
+
 # metadata entry tuple indexes.
 # using raw tuple is faster than namedtuple.
-MD_INDEX: Literal[0] = 0  # integer index in cursor.description
-MD_RESULT_MAP_INDEX: Literal[
-    1
-] = 1  # integer index in compiled._result_columns
-MD_OBJECTS: Literal[
-    2
-] = 2  # other string keys and ColumnElement obj that can match
-MD_LOOKUP_KEY: Literal[
-    3
-] = 3  # string key we usually expect for key-based lookup
-MD_RENDERED_NAME: Literal[4] = 4  # name that is usually in cursor.description
-MD_PROCESSOR: Literal[5] = 5  # callable to process a result value into a row
-MD_UNTRANSLATED: Literal[6] = 6  # raw name from cursor.description
+# these match up to the positions in
+# _CursorKeyMapRecType
+MD_INDEX: Literal[0] = 0
+"""integer index in cursor.description
+
+"""
+
+MD_RESULT_MAP_INDEX: Literal[1] = 1
+"""integer index in compiled._result_columns"""
+
+MD_OBJECTS: Literal[2] = 2
+"""other string keys and ColumnElement obj that can match.
+
+This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects
+
+"""
+
+MD_LOOKUP_KEY: Literal[3] = 3
+"""string key we usually expect for key-based lookup
+
+this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name
+"""
+
+
+MD_RENDERED_NAME: Literal[4] = 4
+"""name that is usually in cursor.description
+
+this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname
+"""
+
+
+MD_PROCESSOR: Literal[5] = 5
+"""callable to process a result value into a row"""
+
+MD_UNTRANSLATED: Literal[6] = 6
+"""raw name from cursor.description"""
 
 
 _CursorKeyMapRecType = Tuple[
-    int, int, List[Any], str, str, Optional["_ResultProcessorType"], str
+    Optional[int],  # MD_INDEX, None means the record is ambiguously named
+    int,  # MD_RESULT_MAP_INDEX
+    List[Any],  # MD_OBJECTS
+    str,  # MD_LOOKUP_KEY
+    str,  # MD_RENDERED_NAME
+    Optional["_ResultProcessorType"],  # MD_PROCESSOR
+    Optional[str],  # MD_UNTRANSLATED
 ]
 
 _CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
 
+# same as _CursorKeyMapRecType except the MD_INDEX value is definitely
+# not None
+_NonAmbigCursorKeyMapRecType = Tuple[
+    int,
+    int,
+    List[Any],
+    str,
+    str,
+    Optional["_ResultProcessorType"],
+    str,
+]
+
 
 class CursorResultMetaData(ResultMetaData):
     """Result metadata for DBAPI cursors."""
@@ -127,38 +173,112 @@ class CursorResultMetaData(ResultMetaData):
             extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
         )
 
-    def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
-        recs = cast(
-            "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys))
+    def _make_new_metadata(
+        self,
+        *,
+        unpickled: bool,
+        processors: _ProcessorsType,
+        keys: Sequence[str],
+        keymap: _KeyMapType,
+        tuplefilter: Optional[_TupleGetterType],
+        translated_indexes: Optional[List[int]],
+        safe_for_cache: bool,
+        keymap_by_result_column_idx: Any,
+    ) -> CursorResultMetaData:
+        new_obj = self.__class__.__new__(self.__class__)
+        new_obj._unpickled = unpickled
+        new_obj._processors = processors
+        new_obj._keys = keys
+        new_obj._keymap = keymap
+        new_obj._tuplefilter = tuplefilter
+        new_obj._translated_indexes = translated_indexes
+        new_obj._safe_for_cache = safe_for_cache
+        new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx
+        return new_obj
+
+    def _remove_processors(self) -> CursorResultMetaData:
+        assert not self._tuplefilter
+        return self._make_new_metadata(
+            unpickled=self._unpickled,
+            processors=[None] * len(self._processors),
+            tuplefilter=None,
+            translated_indexes=None,
+            keymap={
+                key: value[0:5] + (None,) + value[6:]
+                for key, value in self._keymap.items()
+            },
+            keys=self._keys,
+            safe_for_cache=self._safe_for_cache,
+            keymap_by_result_column_idx=self._keymap_by_result_column_idx,
         )
 
+    def _splice_horizontally(
+        self, other: CursorResultMetaData
+    ) -> CursorResultMetaData:
+
+        assert not self._tuplefilter
+
+        keymap = self._keymap.copy()
+        offset = len(self._keys)
+        keymap.update(
+            {
+                key: (
+                    # int index should be None for ambiguous key
+                    value[0] + offset
+                    if value[0] is not None and key not in keymap
+                    else None,
+                    value[1] + offset,
+                    *value[2:],
+                )
+                for key, value in other._keymap.items()
+            }
+        )
+
+        return self._make_new_metadata(
+            unpickled=self._unpickled,
+            processors=self._processors + other._processors,  # type: ignore
+            tuplefilter=None,
+            translated_indexes=None,
+            keys=self._keys + other._keys,  # type: ignore
+            keymap=keymap,
+            safe_for_cache=self._safe_for_cache,
+            keymap_by_result_column_idx={
+                metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
+                for metadata_entry in keymap.values()
+            },
+        )
+
+    def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
+        recs = list(self._metadata_for_keys(keys))
+
         indexes = [rec[MD_INDEX] for rec in recs]
         new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs]
 
         if self._translated_indexes:
             indexes = [self._translated_indexes[idx] for idx in indexes]
         tup = tuplegetter(*indexes)
-
-        new_metadata = self.__class__.__new__(self.__class__)
-        new_metadata._unpickled = self._unpickled
-        new_metadata._processors = self._processors
-        new_metadata._keys = new_keys
-        new_metadata._tuplefilter = tup
-        new_metadata._translated_indexes = indexes
-
         new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)]
-        new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
 
+        keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
         # TODO: need unit test for:
         # result = connection.execute("raw sql, no columns").scalars()
         # without the "or ()" it's failing because MD_OBJECTS is None
-        new_metadata._keymap.update(
+        keymap.update(
             (e, new_rec)
             for new_rec in new_recs
             for e in new_rec[MD_OBJECTS] or ()
         )
 
-        return new_metadata
+        return self._make_new_metadata(
+            unpickled=self._unpickled,
+            processors=self._processors,
+            keys=new_keys,
+            tuplefilter=tup,
+            translated_indexes=indexes,
+            keymap=keymap,
+            safe_for_cache=self._safe_for_cache,
+            keymap_by_result_column_idx=self._keymap_by_result_column_idx,
+        )
 
     def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
         """When using a cached Compiled construct that has a _result_map,
@@ -168,6 +288,7 @@ class CursorResultMetaData(ResultMetaData):
         as matched to those of the cached statement.
 
         """
+
         if not context.compiled or not context.compiled._result_columns:
             return self
 
@@ -189,7 +310,6 @@ class CursorResultMetaData(ResultMetaData):
 
         # make a copy and add the columns from the invoked statement
         # to the result map.
-        md = self.__class__.__new__(self.__class__)
 
         keymap_by_position = self._keymap_by_result_column_idx
 
@@ -201,26 +321,26 @@ class CursorResultMetaData(ResultMetaData):
                 for metadata_entry in self._keymap.values()
             }
 
-        md._keymap = compat.dict_union(
-            self._keymap,
-            {
-                new: keymap_by_position[idx]
-                for idx, new in enumerate(
-                    invoked_statement._all_selected_columns
-                )
-                if idx in keymap_by_position
-            },
-        )
-
-        md._unpickled = self._unpickled
-        md._processors = self._processors
         assert not self._tuplefilter
-        md._tuplefilter = None
-        md._translated_indexes = None
-        md._keys = self._keys
-        md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
-        md._safe_for_cache = self._safe_for_cache
-        return md
+        return self._make_new_metadata(
+            keymap=compat.dict_union(
+                self._keymap,
+                {
+                    new: keymap_by_position[idx]
+                    for idx, new in enumerate(
+                        invoked_statement._all_selected_columns
+                    )
+                    if idx in keymap_by_position
+                },
+            ),
+            unpickled=self._unpickled,
+            processors=self._processors,
+            tuplefilter=None,
+            translated_indexes=None,
+            keys=self._keys,
+            safe_for_cache=self._safe_for_cache,
+            keymap_by_result_column_idx=self._keymap_by_result_column_idx,
+        )
 
     def __init__(
         self,
@@ -683,7 +803,27 @@ class CursorResultMetaData(ResultMetaData):
                 untranslated,
             )
 
-    def _key_fallback(self, key, err, raiseerr=True):
+    @overload
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: Literal[True] = ...
+    ) -> NoReturn:
+        ...
+
+    @overload
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: Literal[False] = ...
+    ) -> None:
+        ...
+
+    @overload
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: bool = ...
+    ) -> Optional[NoReturn]:
+        ...
+
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: bool = True
+    ) -> Optional[NoReturn]:
 
         if raiseerr:
             if self._unpickled and isinstance(key, elements.ColumnElement):
@@ -714,9 +854,9 @@ class CursorResultMetaData(ResultMetaData):
         try:
             rec = self._keymap[key]
         except KeyError as ke:
-            rec = self._key_fallback(key, ke, raiseerr)
-            if rec is None:
-                return None
+            x = self._key_fallback(key, ke, raiseerr)
+            assert x is None
+            return None
 
         index = rec[0]
 
@@ -734,7 +874,7 @@ class CursorResultMetaData(ResultMetaData):
 
     def _metadata_for_keys(
         self, keys: Sequence[Any]
-    ) -> Iterator[_CursorKeyMapRecType]:
+    ) -> Iterator[_NonAmbigCursorKeyMapRecType]:
         for key in keys:
             if int in key.__class__.__mro__:
                 key = self._keys[key]
@@ -750,7 +890,7 @@ class CursorResultMetaData(ResultMetaData):
             if index is None:
                 self._raise_for_ambiguous_column_name(rec)
 
-            yield rec
+            yield cast(_NonAmbigCursorKeyMapRecType, rec)
 
     def __getstate__(self):
         return {
@@ -1237,6 +1377,12 @@ _NO_RESULT_METADATA = _NoResultMetaData()
 SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
 
 
+def null_dml_result() -> IteratorResult[Any]:
+    it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([]))
+    it._soft_close()
+    return it
+
+
 class CursorResult(Result[_T]):
     """A Result that is representing state from a DBAPI cursor.
 
@@ -1586,6 +1732,142 @@ class CursorResult(Result[_T]):
         """
         return self.context.returned_default_rows
 
+    def splice_horizontally(self, other):
+        """Return a new :class:`.CursorResult` that "horizontally splices"
+        together the rows of this :class:`.CursorResult` with that of another
+        :class:`.CursorResult`.
+
+        .. tip::  This method is for the benefit of the SQLAlchemy ORM and is
+           not intended for general use.
+
+        "horizontally splices" means that for each row in the first and second
+        result sets, a new row that concatenates the two rows together is
+        produced, which then becomes the new row.  The incoming
+        :class:`.CursorResult` must have the identical number of rows.  It is
+        typically expected that the two result sets come from the same sort
+        order as well, as the result rows are spliced together based on their
+        position in the result.
+
+        The expected use case here is so that multiple INSERT..RETURNING
+        statements against different tables can produce a single result
+        that looks like a JOIN of those two tables.
+
+        E.g.::
+
+            r1 = connection.execute(
+                users.insert().returning(users.c.user_name, users.c.user_id),
+                user_values
+            )
+
+            r2 = connection.execute(
+                addresses.insert().returning(
+                    addresses.c.address_id,
+                    addresses.c.address,
+                    addresses.c.user_id,
+                ),
+                address_values
+            )
+
+            rows = r1.splice_horizontally(r2).all()
+            assert (
+                rows ==
+                [
+                    ("john", 1, 1, "foo@bar.com", 1),
+                    ("jack", 2, 2, "bar@bat.com", 2),
+                ]
+            )
+
+        .. versionadded:: 2.0
+
+        .. seealso::
+
+            :meth:`.CursorResult.splice_vertically`
+
+
+        """
+
+        clone = self._generate()
+        total_rows = [
+            tuple(r1) + tuple(r2)
+            for r1, r2 in zip(
+                list(self._raw_row_iterator()),
+                list(other._raw_row_iterator()),
+            )
+        ]
+
+        clone._metadata = clone._metadata._splice_horizontally(other._metadata)
+
+        clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
+            None,
+            initial_buffer=total_rows,
+        )
+        clone._reset_memoizations()
+        return clone
+
+    def splice_vertically(self, other):
+        """Return a new :class:`.CursorResult` that "vertically splices",
+        i.e. "extends", the rows of this :class:`.CursorResult` with that of
+        another :class:`.CursorResult`.
+
+        .. tip::  This method is for the benefit of the SQLAlchemy ORM and is
+           not intended for general use.
+
+        "vertically splices" means the rows of the given result are appended to
+        the rows of this cursor result. The incoming :class:`.CursorResult`
+        must have rows that represent the identical list of columns in the
+        identical order as they are in this :class:`.CursorResult`.
+
+        .. versionadded:: 2.0
+
+        .. seealso::
+
+            :meth:`.CursorResult.splice_horizontally`
+
+        """
+        clone = self._generate()
+        total_rows = list(self._raw_row_iterator()) + list(
+            other._raw_row_iterator()
+        )
+
+        clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
+            None,
+            initial_buffer=total_rows,
+        )
+        clone._reset_memoizations()
+        return clone
+
+    def _rewind(self, rows):
+        """rewind this result back to the given rowset.
+
+        this is used internally for the case where an :class:`.Insert`
+        construct combines the use of
+        :meth:`.Insert.return_defaults` along with the
+        "supplemental columns" feature.
+
+        """
+
+        if self._echo:
+            self.context.connection._log_debug(
+                "CursorResult rewound %d row(s)", len(rows)
+            )
+
+        # the rows given are expected to be Row objects, so we
+        # have to clear out processors which have already run on these
+        # rows
+        self._metadata = cast(
+            CursorResultMetaData, self._metadata
+        )._remove_processors()
+
+        self.cursor_strategy = FullyBufferedCursorFetchStrategy(
+            None,
+            # TODO: if these are Row objects, can we save on not having to
+            # re-make new Row objects out of them a second time?  is that
+            # what's actually happening right now?  maybe look into this
+            initial_buffer=rows,
+        )
+        self._reset_memoizations()
+        return self
+
     @property
     def returned_defaults(self):
         """Return the values of default columns that were fetched using
lib/sqlalchemy/engine/default.py
index 11ab713d08a88a3399644808ecd07cb3ff129461..cb3d0528fd679e387ea844e9679cd4ced8970abd 100644 (file)
@@ -1007,6 +1007,7 @@ class DefaultExecutionContext(ExecutionContext):
 
     _is_implicit_returning = False
     _is_explicit_returning = False
+    _is_supplemental_returning = False
     _is_server_side = False
 
     _soft_closed = False
@@ -1125,18 +1126,19 @@ class DefaultExecutionContext(ExecutionContext):
         self.is_text = compiled.isplaintext
 
         if ii or iu or id_:
+            dml_statement = compiled.compile_state.statement  # type: ignore
             if TYPE_CHECKING:
-                assert isinstance(compiled.statement, UpdateBase)
+                assert isinstance(dml_statement, UpdateBase)
             self.is_crud = True
-            self._is_explicit_returning = ier = bool(
-                compiled.statement._returning
-            )
-            self._is_implicit_returning = iir = is_implicit_returning = bool(
+            self._is_explicit_returning = ier = bool(dml_statement._returning)
+            self._is_implicit_returning = iir = bool(
                 compiled.implicit_returning
             )
-            assert not (
-                is_implicit_returning and compiled.statement._returning
-            )
+            if iir and dml_statement._supplemental_returning:
+                self._is_supplemental_returning = True
+
+            # don't mix implicit and explicit returning
+            assert not (iir and ier)
 
             if (ier or iir) and compiled.for_executemany:
                 if ii and not self.dialect.insert_executemany_returning:
@@ -1711,7 +1713,14 @@ class DefaultExecutionContext(ExecutionContext):
                 # are that the result has only one row, until executemany()
                 # support is added here.
                 assert result._metadata.returns_rows
-                result._soft_close()
+
+                # Insert statement has both return_defaults() and
+                # returning().  rewind the result on the list of rows
+                # we just used.
+                if self._is_supplemental_returning:
+                    result._rewind(rows)
+                else:
+                    result._soft_close()
             elif not self._is_explicit_returning:
                 result._soft_close()
 
@@ -1721,21 +1730,18 @@ class DefaultExecutionContext(ExecutionContext):
                 # function so this is not necessarily true.
                 # assert not result.returns_rows
 
-        elif self.isupdate and self._is_implicit_returning:
-            # get rowcount
-            # (which requires open cursor on some drivers)
-            # we were not doing this in 1.4, however
-            # test_rowcount -> test_update_rowcount_return_defaults
-            # is testing this, and psycopg will no longer return
-            # rowcount after cursor is closed.
-            result.rowcount
-            self._has_rowcount = True
+        elif self._is_implicit_returning:
+            rows = result.all()
 
-            row = result.fetchone()
-            if row is not None:
-                self.returned_default_rows = [row]
+            if rows:
+                self.returned_default_rows = rows
+            result.rowcount = len(rows)
+            self._has_rowcount = True
 
-            result._soft_close()
+            if self._is_supplemental_returning:
+                result._rewind(rows)
+            else:
+                result._soft_close()
 
             # test that it has a cursor metadata that is accurate.
             # the rows have all been fetched however.
@@ -1750,7 +1756,6 @@ class DefaultExecutionContext(ExecutionContext):
         elif self.isupdate or self.isdelete:
             result.rowcount
             self._has_rowcount = True
-
         return result
 
     @util.memoized_property
lib/sqlalchemy/engine/result.py
index df5a8199cc8818a620bff98073ae21bd3bfd0a8c..05ca17063b028244b744bfaf55948a849a465c0f 100644 (file)
@@ -109,9 +109,27 @@ class ResultMetaData:
     def _for_freeze(self) -> ResultMetaData:
         raise NotImplementedError()
 
+    @overload
     def _key_fallback(
-        self, key: _KeyType, err: Exception, raiseerr: bool = True
+        self, key: Any, err: Exception, raiseerr: Literal[True] = ...
     ) -> NoReturn:
+        ...
+
+    @overload
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: Literal[False] = ...
+    ) -> None:
+        ...
+
+    @overload
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: bool = ...
+    ) -> Optional[NoReturn]:
+        ...
+
+    def _key_fallback(
+        self, key: Any, err: Exception, raiseerr: bool = True
+    ) -> Optional[NoReturn]:
         assert raiseerr
         raise KeyError(key) from err
 
@@ -2148,6 +2166,7 @@ class IteratorResult(Result[_TP]):
     """
 
     _hard_closed = False
+    _soft_closed = False
 
     def __init__(
         self,
@@ -2168,6 +2187,7 @@ class IteratorResult(Result[_TP]):
             self.raw._soft_close(hard=hard, **kw)
         self.iterator = iter([])
         self._reset_memoizations()
+        self._soft_closed = True
 
     def _raise_hard_closed(self) -> NoReturn:
         raise exc.ResourceClosedError("This result object is closed.")
lib/sqlalchemy/orm/bulk_persistence.py
index 225292d17de124d4b2febeac0cd2e39588f38357..3ed34a57a4d7967c2480b9e9047748a34fda96b1 100644 (file)
@@ -15,24 +15,32 @@ specifically outside of the flush() process.
 from __future__ import annotations
 
 from typing import Any
+from typing import cast
 from typing import Dict
 from typing import Iterable
+from typing import Optional
+from typing import overload
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
 from . import attributes
+from . import context
 from . import evaluator
 from . import exc as orm_exc
+from . import loading
 from . import persistence
 from .base import NO_VALUE
 from .context import AbstractORMCompileState
+from .context import FromStatement
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
 from .. import exc as sa_exc
-from .. import sql
 from .. import util
 from ..engine import Dialect
 from ..engine import result as _result
 from ..sql import coercions
+from ..sql import dml
 from ..sql import expression
 from ..sql import roles
 from ..sql import select
@@ -48,16 +56,24 @@ from ..util.typing import Literal
 
 if TYPE_CHECKING:
     from .mapper import Mapper
+    from .session import _BindArguments
     from .session import ORMExecuteState
+    from .session import Session
     from .session import SessionTransaction
     from .state import InstanceState
+    from ..engine import Connection
+    from ..engine import cursor
+    from ..engine.interfaces import _CoreAnyExecuteParams
+    from ..engine.interfaces import _ExecuteOptionsParameter
 
 _O = TypeVar("_O", bound=object)
 
 
-_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"]
+_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"]
 
 
+@overload
 def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
@@ -65,7 +81,36 @@ def _bulk_insert(
     isstates: bool,
     return_defaults: bool,
     render_nulls: bool,
+    use_orm_insert_stmt: Literal[None] = ...,
+    execution_options: Optional[_ExecuteOptionsParameter] = ...,
 ) -> None:
+    ...
+
+
+@overload
+def _bulk_insert(
+    mapper: Mapper[_O],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    return_defaults: bool,
+    render_nulls: bool,
+    use_orm_insert_stmt: Optional[dml.Insert] = ...,
+    execution_options: Optional[_ExecuteOptionsParameter] = ...,
+) -> cursor.CursorResult[Any]:
+    ...
+
+
+def _bulk_insert(
+    mapper: Mapper[_O],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    return_defaults: bool,
+    render_nulls: bool,
+    use_orm_insert_stmt: Optional[dml.Insert] = None,
+    execution_options: Optional[_ExecuteOptionsParameter] = None,
+) -> Optional[cursor.CursorResult[Any]]:
     base_mapper = mapper.base_mapper
 
     if session_transaction.session.connection_callable:
@@ -81,13 +126,27 @@ def _bulk_insert(
         else:
             mappings = [state.dict for state in mappings]
     else:
-        mappings = list(mappings)
+        mappings = [dict(m) for m in mappings]
+        _expand_composites(mapper, mappings)
 
     connection = session_transaction.connection(base_mapper)
+
+    return_result: Optional[cursor.CursorResult[Any]] = None
+
     for table, super_mapper in base_mapper._sorted_tables.items():
-        if not mapper.isa(super_mapper):
+        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
             continue
 
+        is_joined_inh_supertable = super_mapper is not mapper
+        bookkeeping = (
+            is_joined_inh_supertable
+            or return_defaults
+            or (
+                use_orm_insert_stmt is not None
+                and bool(use_orm_insert_stmt._returning)
+            )
+        )
+
         records = (
             (
                 None,
@@ -112,18 +171,25 @@ def _bulk_insert(
                 table,
                 ((None, mapping, mapper, connection) for mapping in mappings),
                 bulk=True,
-                return_defaults=return_defaults,
+                return_defaults=bookkeeping,
                 render_nulls=render_nulls,
             )
         )
-        persistence._emit_insert_statements(
+        result = persistence._emit_insert_statements(
             base_mapper,
             None,
             super_mapper,
             table,
             records,
-            bookkeeping=return_defaults,
+            bookkeeping=bookkeeping,
+            use_orm_insert_stmt=use_orm_insert_stmt,
+            execution_options=execution_options,
         )
+        if use_orm_insert_stmt is not None:
+            if not use_orm_insert_stmt._returning or return_result is None:
+                return_result = result
+            elif result.returns_rows:
+                return_result = return_result.splice_horizontally(result)
 
     if return_defaults and isstates:
         identity_cls = mapper._identity_class
@@ -134,14 +200,43 @@ def _bulk_insert(
                 tuple([dict_[key] for key in identity_props]),
             )
 
+    if use_orm_insert_stmt is not None:
+        assert return_result is not None
+        return return_result
 
+
+@overload
 def _bulk_update(
     mapper: Mapper[Any],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
     session_transaction: SessionTransaction,
     isstates: bool,
     update_changed_only: bool,
+    use_orm_update_stmt: Literal[None] = ...,
 ) -> None:
+    ...
+
+
+@overload
+def _bulk_update(
+    mapper: Mapper[Any],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    update_changed_only: bool,
+    use_orm_update_stmt: Optional[dml.Update] = ...,
+) -> _result.Result[Any]:
+    ...
+
+
+def _bulk_update(
+    mapper: Mapper[Any],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    update_changed_only: bool,
+    use_orm_update_stmt: Optional[dml.Update] = None,
+) -> Optional[_result.Result[Any]]:
     base_mapper = mapper.base_mapper
 
     search_keys = mapper._primary_key_propkeys
@@ -161,7 +256,8 @@ def _bulk_update(
         else:
             mappings = [state.dict for state in mappings]
     else:
-        mappings = list(mappings)
+        mappings = [dict(m) for m in mappings]
+        _expand_composites(mapper, mappings)
 
     if session_transaction.session.connection_callable:
         raise NotImplementedError(
@@ -172,7 +268,7 @@ def _bulk_update(
     connection = session_transaction.connection(base_mapper)
 
     for table, super_mapper in base_mapper._sorted_tables.items():
-        if not mapper.isa(super_mapper):
+        if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
             continue
 
         records = persistence._collect_update_commands(
@@ -193,8 +289,8 @@ def _bulk_update(
                 for mapping in mappings
             ),
             bulk=True,
+            use_orm_update_stmt=use_orm_update_stmt,
         )
-
         persistence._emit_update_statements(
             base_mapper,
             None,
@@ -202,10 +298,125 @@ def _bulk_update(
             table,
             records,
             bookkeeping=False,
+            use_orm_update_stmt=use_orm_update_stmt,
         )
 
+    if use_orm_update_stmt is not None:
+        return _result.null_result()
+
+
+def _expand_composites(mapper, mappings):
+    composite_attrs = mapper.composites
+    if not composite_attrs:
+        return
+
+    composite_keys = set(composite_attrs.keys())
+    populators = {
+        key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
+        for key in composite_keys
+    }
+    for mapping in mappings:
+        for key in composite_keys.intersection(mapping):
+            populators[key](mapping)
+
 
 class ORMDMLState(AbstractORMCompileState):
+    is_dml_returning = True
+    from_statement_ctx: Optional[ORMFromStatementCompileState] = None
+
+    @classmethod
+    def _get_orm_crud_kv_pairs(
+        cls, mapper, statement, kv_iterator, needs_to_be_cacheable
+    ):
+
+        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+        for k, v in kv_iterator:
+            k = coercions.expect(roles.DMLColumnRole, k)
+
+            if isinstance(k, str):
+                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+                if desc is NO_VALUE:
+                    yield (
+                        coercions.expect(roles.DMLColumnRole, k),
+                        coercions.expect(
+                            roles.ExpressionElementRole,
+                            v,
+                            type_=sqltypes.NullType(),
+                            is_crud=True,
+                        )
+                        if needs_to_be_cacheable
+                        else v,
+                    )
+                else:
+                    yield from core_get_crud_kv_pairs(
+                        statement,
+                        desc._bulk_update_tuples(v),
+                        needs_to_be_cacheable,
+                    )
+            elif "entity_namespace" in k._annotations:
+                k_anno = k._annotations
+                attr = _entity_namespace_key(
+                    k_anno["entity_namespace"], k_anno["proxy_key"]
+                )
+                yield from core_get_crud_kv_pairs(
+                    statement,
+                    attr._bulk_update_tuples(v),
+                    needs_to_be_cacheable,
+                )
+            else:
+                yield (
+                    k,
+                    v
+                    if not needs_to_be_cacheable
+                    else coercions.expect(
+                        roles.ExpressionElementRole,
+                        v,
+                        type_=sqltypes.NullType(),
+                        is_crud=True,
+                    ),
+                )
+
+    @classmethod
+    def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
+        plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+        if not plugin_subject or not plugin_subject.mapper:
+            return UpdateDMLState._get_multi_crud_kv_pairs(
+                statement, kv_iterator
+            )
+
+        return [
+            dict(
+                cls._get_orm_crud_kv_pairs(
+                    plugin_subject.mapper, statement, value_dict.items(), False
+                )
+            )
+            for value_dict in kv_iterator
+        ]
+
+    @classmethod
+    def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
+        assert (
+            needs_to_be_cacheable
+        ), "no test coverage for needs_to_be_cacheable=False"
+
+        plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+        if not plugin_subject or not plugin_subject.mapper:
+            return UpdateDMLState._get_crud_kv_pairs(
+                statement, kv_iterator, needs_to_be_cacheable
+            )
+
+        return list(
+            cls._get_orm_crud_kv_pairs(
+                plugin_subject.mapper,
+                statement,
+                kv_iterator,
+                needs_to_be_cacheable,
+            )
+        )
+
     @classmethod
     def get_entity_description(cls, statement):
         ext_info = statement.table._annotations["parententity"]
@@ -250,18 +461,101 @@ class ORMDMLState(AbstractORMCompileState):
             ]
         ]
 
+    def _setup_orm_returning(
+        self,
+        compiler,
+        orm_level_statement,
+        dml_level_statement,
+        use_supplemental_cols=True,
+        dml_mapper=None,
+    ):
+        """establish ORM column handlers for an INSERT, UPDATE, or DELETE
+        which uses explicit returning().
+
+        called within compilation level create_for_statement.
+
+        The _return_orm_returning() method then receives the Result
+        after the statement was executed, and applies ORM loading to the
+        state that we first established here.
+
+        """
+
+        if orm_level_statement._returning:
+
+            fs = FromStatement(
+                orm_level_statement._returning, dml_level_statement
+            )
+            fs = fs.options(*orm_level_statement._with_options)
+            self.select_statement = fs
+            self.from_statement_ctx = (
+                fsc
+            ) = ORMFromStatementCompileState.create_for_statement(fs, compiler)
+            fsc.setup_dml_returning_compile_state(dml_mapper)
+
+            dml_level_statement = dml_level_statement._generate()
+            dml_level_statement._returning = ()
+
+            cols_to_return = [c for c in fsc.primary_columns if c is not None]
+
+            # since we are splicing result sets together, make sure there
+            # are columns of some kind returned in each result set
+            if not cols_to_return:
+                cols_to_return.extend(dml_mapper.primary_key)
+
+            if use_supplemental_cols:
+                dml_level_statement = dml_level_statement.return_defaults(
+                    supplemental_cols=cols_to_return
+                )
+            else:
+                dml_level_statement = dml_level_statement.returning(
+                    *cols_to_return
+                )
+
+        return dml_level_statement
+
+    @classmethod
+    def _return_orm_returning(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        result,
+    ):
+
+        execution_context = result.context
+        compile_state = execution_context.compiled.compile_state
+
+        if compile_state.from_statement_ctx:
+            load_options = execution_options.get(
+                "_sa_orm_load_options", QueryContext.default_load_options
+            )
+            querycontext = QueryContext(
+                compile_state.from_statement_ctx,
+                compile_state.select_statement,
+                params,
+                session,
+                load_options,
+                execution_options,
+                bind_arguments,
+            )
+            return loading.instances(result, querycontext)
+        else:
+            return result
+
 
 class BulkUDCompileState(ORMDMLState):
     class default_update_options(Options):
-        _synchronize_session: _SynchronizeSessionArgument = "evaluate"
-        _is_delete_using = False
-        _is_update_from = False
-        _autoflush = True
-        _subject_mapper = None
+        _dml_strategy: _DMLStrategyArgument = "auto"
+        _synchronize_session: _SynchronizeSessionArgument = "auto"
+        _can_use_returning: bool = False
+        _is_delete_using: bool = False
+        _is_update_from: bool = False
+        _autoflush: bool = True
+        _subject_mapper: Optional[Mapper[Any]] = None
         _resolved_values = EMPTY_DICT
-        _resolved_keys_as_propnames = EMPTY_DICT
-        _value_evaluators = EMPTY_DICT
-        _matched_objects = None
+        _eval_condition = None
         _matched_rows = None
         _refresh_identity_token = None
 
@@ -295,19 +589,16 @@ class BulkUDCompileState(ORMDMLState):
             execution_options,
         ) = BulkUDCompileState.default_update_options.from_execution_options(
             "_sa_orm_update_options",
-            {"synchronize_session", "is_delete_using", "is_update_from"},
+            {
+                "synchronize_session",
+                "is_delete_using",
+                "is_update_from",
+                "dml_strategy",
+            },
             execution_options,
             statement._execution_options,
         )
 
-        sync = update_options._synchronize_session
-        if sync is not None:
-            if sync not in ("evaluate", "fetch", False):
-                raise sa_exc.ArgumentError(
-                    "Valid strategies for session synchronization "
-                    "are 'evaluate', 'fetch', False"
-                )
-
         bind_arguments["clause"] = statement
         try:
             plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -318,43 +609,86 @@ class BulkUDCompileState(ORMDMLState):
 
         update_options += {"_subject_mapper": plugin_subject.mapper}
 
+        if not isinstance(params, list):
+            if update_options._dml_strategy == "auto":
+                update_options += {"_dml_strategy": "orm"}
+            elif update_options._dml_strategy == "bulk":
+                raise sa_exc.InvalidRequestError(
+                    'Can\'t use "bulk" ORM insert strategy without '
+                    "passing separate parameters"
+                )
+        else:
+            if update_options._dml_strategy == "auto":
+                update_options += {"_dml_strategy": "bulk"}
+            elif update_options._dml_strategy == "orm":
+                raise sa_exc.InvalidRequestError(
+                    'Can\'t use "orm" ORM insert strategy with a '
+                    "separate parameter list"
+                )
+
+        sync = update_options._synchronize_session
+        if sync is not None:
+            if sync not in ("auto", "evaluate", "fetch", False):
+                raise sa_exc.ArgumentError(
+                    "Valid strategies for session synchronization "
+                    "are 'auto', 'evaluate', 'fetch', False"
+                )
+            if update_options._dml_strategy == "bulk" and sync == "fetch":
+                raise sa_exc.InvalidRequestError(
+                    "The 'fetch' synchronization strategy is not available "
+                    "for 'bulk' ORM updates (i.e. multiple parameter sets)"
+                )
+
         if update_options._autoflush:
             session._autoflush()
 
+        if update_options._dml_strategy == "orm":
+
+            if update_options._synchronize_session == "auto":
+                update_options = cls._do_pre_synchronize_auto(
+                    session,
+                    statement,
+                    params,
+                    execution_options,
+                    bind_arguments,
+                    update_options,
+                )
+            elif update_options._synchronize_session == "evaluate":
+                update_options = cls._do_pre_synchronize_evaluate(
+                    session,
+                    statement,
+                    params,
+                    execution_options,
+                    bind_arguments,
+                    update_options,
+                )
+            elif update_options._synchronize_session == "fetch":
+                update_options = cls._do_pre_synchronize_fetch(
+                    session,
+                    statement,
+                    params,
+                    execution_options,
+                    bind_arguments,
+                    update_options,
+                )
+        elif update_options._dml_strategy == "bulk":
+            if update_options._synchronize_session == "auto":
+                update_options += {"_synchronize_session": "evaluate"}
+
+        # indicators from the "pre exec" step that are then
+        # added to the DML statement, which will also be part of the cache
+        # key.  The compile level create_for_statement() method will then
+        # consume these at compiler time.
         statement = statement._annotate(
             {
                 "synchronize_session": update_options._synchronize_session,
                 "is_delete_using": update_options._is_delete_using,
                 "is_update_from": update_options._is_update_from,
+                "dml_strategy": update_options._dml_strategy,
+                "can_use_returning": update_options._can_use_returning,
             }
         )
 
-        # this stage of the execution is called before the do_orm_execute event
-        # hook.  meaning for an extension like horizontal sharding, this step
-        # happens before the extension splits out into multiple backends and
-        # runs only once.  if we do pre_sync_fetch, we execute a SELECT
-        # statement, which the horizontal sharding extension splits amongst the
-        # shards and combines the results together.
-
-        if update_options._synchronize_session == "evaluate":
-            update_options = cls._do_pre_synchronize_evaluate(
-                session,
-                statement,
-                params,
-                execution_options,
-                bind_arguments,
-                update_options,
-            )
-        elif update_options._synchronize_session == "fetch":
-            update_options = cls._do_pre_synchronize_fetch(
-                session,
-                statement,
-                params,
-                execution_options,
-                bind_arguments,
-                update_options,
-            )
-
         return (
             statement,
             util.immutabledict(execution_options).union(
@@ -382,12 +716,30 @@ class BulkUDCompileState(ORMDMLState):
         # individual ones we return here.
 
         update_options = execution_options["_sa_orm_update_options"]
-        if update_options._synchronize_session == "evaluate":
-            cls._do_post_synchronize_evaluate(session, result, update_options)
-        elif update_options._synchronize_session == "fetch":
-            cls._do_post_synchronize_fetch(session, result, update_options)
+        if update_options._dml_strategy == "orm":
+            if update_options._synchronize_session == "evaluate":
+                cls._do_post_synchronize_evaluate(
+                    session, statement, result, update_options
+                )
+            elif update_options._synchronize_session == "fetch":
+                cls._do_post_synchronize_fetch(
+                    session, statement, result, update_options
+                )
+        elif update_options._dml_strategy == "bulk":
+            if update_options._synchronize_session == "evaluate":
+                cls._do_post_synchronize_bulk_evaluate(
+                    session, params, result, update_options
+                )
+            return result
 
-        return result
+        return cls._return_orm_returning(
+            session,
+            statement,
+            params,
+            execution_options,
+            bind_arguments,
+            result,
+        )
 
     @classmethod
     def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
@@ -473,11 +825,76 @@ class BulkUDCompileState(ORMDMLState):
         primary_key_convert = [
             lookup[bpk] for bpk in mapper.base_mapper.primary_key
         ]
-
         return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
 
     @classmethod
-    def _do_pre_synchronize_evaluate(
+    def _get_matched_objects_on_criteria(cls, update_options, states):
+        mapper = update_options._subject_mapper
+        eval_condition = update_options._eval_condition
+
+        raw_data = [
+            (state.obj(), state, state.dict)
+            for state in states
+            if state.mapper.isa(mapper) and not state.expired
+        ]
+
+        identity_token = update_options._refresh_identity_token
+        if identity_token is not None:
+            raw_data = [
+                (obj, state, dict_)
+                for obj, state, dict_ in raw_data
+                if state.identity_token == identity_token
+            ]
+
+        result = []
+        for obj, state, dict_ in raw_data:
+            evaled_condition = eval_condition(obj)
+
+            # caution: don't use "in ()" or == here, _EXPIRED_OBJECT
+            # evaluates as True for all comparisons
+            if (
+                evaled_condition is True
+                or evaled_condition is evaluator._EXPIRED_OBJECT
+            ):
+                result.append(
+                    (
+                        obj,
+                        state,
+                        dict_,
+                        evaled_condition is evaluator._EXPIRED_OBJECT,
+                    )
+                )
+        return result
+
+    @classmethod
+    def _eval_condition_from_statement(cls, update_options, statement):
+        mapper = update_options._subject_mapper
+        target_cls = mapper.class_
+
+        evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+        crit = ()
+        if statement._where_criteria:
+            crit += statement._where_criteria
+
+        global_attributes = {}
+        for opt in statement._with_options:
+            if opt._is_criteria_option:
+                opt.get_global_criteria(global_attributes)
+
+        if global_attributes:
+            crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
+
+        if crit:
+            eval_condition = evaluator_compiler.process(*crit)
+        else:
+
+            def eval_condition(obj):
+                return True
+
+        return eval_condition
+
+    @classmethod
+    def _do_pre_synchronize_auto(
         cls,
         session,
         statement,
@@ -486,33 +903,59 @@ class BulkUDCompileState(ORMDMLState):
         bind_arguments,
         update_options,
     ):
-        mapper = update_options._subject_mapper
-        target_cls = mapper.class_
+        """setup auto sync strategy
+
+
+        "auto" checks if we can use "evaluate" first, then falls back
+        to "fetch"
 
-        value_evaluators = resolved_keys_as_propnames = EMPTY_DICT
+        evaluate is vastly more efficient for the common case
+        where the session is empty or only has a few objects, and the UPDATE
+        statement can potentially match thousands/millions of rows.
+
+        OTOH, more complex criteria that fails to work with "evaluate"
+        will hopefully correlate with fewer net rows.
+
+        """
 
         try:
-            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
-            crit = ()
-            if statement._where_criteria:
-                crit += statement._where_criteria
+            eval_condition = cls._eval_condition_from_statement(
+                update_options, statement
+            )
 
-            global_attributes = {}
-            for opt in statement._with_options:
-                if opt._is_criteria_option:
-                    opt.get_global_criteria(global_attributes)
+        except evaluator.UnevaluatableError:
+            pass
+        else:
+            return update_options + {
+                "_eval_condition": eval_condition,
+                "_synchronize_session": "evaluate",
+            }
 
-            if global_attributes:
-                crit += cls._adjust_for_extra_criteria(
-                    global_attributes, mapper
-                )
+        update_options += {"_synchronize_session": "fetch"}
+        return cls._do_pre_synchronize_fetch(
+            session,
+            statement,
+            params,
+            execution_options,
+            bind_arguments,
+            update_options,
+        )
 
-            if crit:
-                eval_condition = evaluator_compiler.process(*crit)
-            else:
+    @classmethod
+    def _do_pre_synchronize_evaluate(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        update_options,
+    ):
 
-                def eval_condition(obj):
-                    return True
+        try:
+            eval_condition = cls._eval_condition_from_statement(
+                update_options, statement
+            )
 
         except evaluator.UnevaluatableError as err:
             raise sa_exc.InvalidRequestError(
@@ -521,52 +964,8 @@ class BulkUDCompileState(ORMDMLState):
                 "synchronize_session execution option." % err
             ) from err
 
-        if statement.__visit_name__ == "lambda_element":
-            # ._resolved is called on every LambdaElement in order to
-            # generate the cache key, so this access does not add
-            # additional expense
-            effective_statement = statement._resolved
-        else:
-            effective_statement = statement
-
-        if effective_statement.__visit_name__ == "update":
-            resolved_values = cls._get_resolved_values(
-                mapper, effective_statement
-            )
-            value_evaluators = {}
-            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
-                mapper, resolved_values
-            )
-            for key, value in resolved_keys_as_propnames:
-                try:
-                    _evaluator = evaluator_compiler.process(
-                        coercions.expect(roles.ExpressionElementRole, value)
-                    )
-                except evaluator.UnevaluatableError:
-                    pass
-                else:
-                    value_evaluators[key] = _evaluator
-
-        # TODO: detect when the where clause is a trivial primary key match.
-        matched_objects = [
-            state.obj()
-            for state in session.identity_map.all_states()
-            if state.mapper.isa(mapper)
-            and not state.expired
-            and eval_condition(state.obj())
-            and (
-                update_options._refresh_identity_token is None
-                # TODO: coverage for the case where horizontal sharding
-                # invokes an update() or delete() given an explicit identity
-                # token up front
-                or state.identity_token
-                == update_options._refresh_identity_token
-            )
-        ]
         return update_options + {
-            "_matched_objects": matched_objects,
-            "_value_evaluators": value_evaluators,
-            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+            "_eval_condition": eval_condition,
         }
 
     @classmethod
@@ -584,12 +983,6 @@ class BulkUDCompileState(ORMDMLState):
     def _resolved_keys_as_propnames(cls, mapper, resolved_values):
         values = []
         for k, v in resolved_values:
-            if isinstance(k, attributes.QueryableAttribute):
-                values.append((k.key, v))
-                continue
-            elif hasattr(k, "__clause_element__"):
-                k = k.__clause_element__()
-
             if mapper and isinstance(k, expression.ColumnElement):
                 try:
                     attr = mapper._columntoproperty[k]
@@ -599,7 +992,8 @@ class BulkUDCompileState(ORMDMLState):
                     values.append((attr.key, v))
             else:
                 raise sa_exc.InvalidRequestError(
-                    "Invalid expression type: %r" % k
+                    "Attribute name not found, can't be "
+                    "synchronized back to objects: %r" % k
                 )
         return values
 
@@ -622,17 +1016,46 @@ class BulkUDCompileState(ORMDMLState):
         )
         select_stmt._where_criteria = statement._where_criteria
 
+        # conditionally run the SELECT statement for pre-fetch, testing the
+        # "bind" for whether we can use RETURNING or not using the
+        # do_orm_execute event.  If RETURNING is available, the
+        # do_orm_execute event will cancel the SELECT from being
+        # actually run.
+        #
+        # The way this is organized seems strange; why don't we just
+        # call can_use_returning() before invoking the statement and get the
+        # answer?  Why does this go through the whole execute phase using an
+        # event?  Answer: because we are integrating with extensions such
+        # as the horizontal sharding extension that "multiplexes" an
+        # individual statement run through multiple engines, and it uses
+        # do_orm_execute() to do that.
+
+        can_use_returning = None
+
         def skip_for_returning(orm_context: ORMExecuteState) -> Any:
             bind = orm_context.session.get_bind(**orm_context.bind_arguments)
-            if cls.can_use_returning(
+            nonlocal can_use_returning
+
+            per_bind_result = cls.can_use_returning(
                 bind.dialect,
                 mapper,
                 is_update_from=update_options._is_update_from,
                 is_delete_using=update_options._is_delete_using,
-            ):
-                return _result.null_result()
-            else:
-                return None
+            )
+
+            if can_use_returning is not None:
+                if can_use_returning != per_bind_result:
+                    raise sa_exc.InvalidRequestError(
+                        "For synchronize_session='fetch', can't mix multiple "
+                        "backends where some support RETURNING and others "
+                        "don't"
+                    )
+            else:
+                can_use_returning = per_bind_result
+
+            if per_bind_result:
+                return _result.null_result()
+            else:
+                return None
 
         result = session.execute(
             select_stmt,
@@ -643,52 +1066,22 @@ class BulkUDCompileState(ORMDMLState):
         )
         matched_rows = result.fetchall()
 
-        value_evaluators = EMPTY_DICT
-
-        if statement.__visit_name__ == "lambda_element":
-            # ._resolved is called on every LambdaElement in order to
-            # generate the cache key, so this access does not add
-            # additional expense
-            effective_statement = statement._resolved
-        else:
-            effective_statement = statement
-
-        if effective_statement.__visit_name__ == "update":
-            target_cls = mapper.class_
-            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
-            resolved_values = cls._get_resolved_values(
-                mapper, effective_statement
-            )
-            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
-                mapper, resolved_values
-            )
-
-            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
-                mapper, resolved_values
-            )
-            value_evaluators = {}
-            for key, value in resolved_keys_as_propnames:
-                try:
-                    _evaluator = evaluator_compiler.process(
-                        coercions.expect(roles.ExpressionElementRole, value)
-                    )
-                except evaluator.UnevaluatableError:
-                    pass
-                else:
-                    value_evaluators[key] = _evaluator
-
-        else:
-            resolved_keys_as_propnames = EMPTY_DICT
-
         return update_options + {
-            "_value_evaluators": value_evaluators,
             "_matched_rows": matched_rows,
-            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+            "_can_use_returning": can_use_returning,
         }
 
 
 @CompileState.plugin_for("orm", "insert")
-class ORMInsert(ORMDMLState, InsertDMLState):
+class BulkORMInsert(ORMDMLState, InsertDMLState):
+    class default_insert_options(Options):
+        _dml_strategy: _DMLStrategyArgument = "auto"
+        _render_nulls: bool = False
+        _return_defaults: bool = False
+        _subject_mapper: Optional[Mapper[Any]] = None
+
+    select_statement: Optional[FromStatement] = None
+
     @classmethod
     def orm_pre_session_exec(
         cls,
@@ -699,6 +1092,16 @@ class ORMInsert(ORMDMLState, InsertDMLState):
         bind_arguments,
         is_reentrant_invoke,
     ):
+
+        (
+            insert_options,
+            execution_options,
+        ) = BulkORMInsert.default_insert_options.from_execution_options(
+            "_sa_orm_insert_options",
+            {"dml_strategy"},
+            execution_options,
+            statement._execution_options,
+        )
         bind_arguments["clause"] = statement
         try:
             plugin_subject = statement._propagate_attrs["plugin_subject"]
@@ -707,22 +1110,209 @@ class ORMInsert(ORMDMLState, InsertDMLState):
         else:
             bind_arguments["mapper"] = plugin_subject.mapper
 
+        insert_options += {"_subject_mapper": plugin_subject.mapper}
+
+        if not params:
+            if insert_options._dml_strategy == "auto":
+                insert_options += {"_dml_strategy": "orm"}
+            elif insert_options._dml_strategy == "bulk":
+                raise sa_exc.InvalidRequestError(
+                    'Can\'t use "bulk" ORM insert strategy without '
+                    "passing separate parameters"
+                )
+        else:
+            if insert_options._dml_strategy == "auto":
+                insert_options += {"_dml_strategy": "bulk"}
+            elif insert_options._dml_strategy == "orm":
+                raise sa_exc.InvalidRequestError(
+                    'Can\'t use "orm" ORM insert strategy with a '
+                    "separate parameter list"
+                )
+
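+        # e.g. (illustrative; ``User`` is a hypothetical mapped class) how
+        # the strategy resolves under "auto":
+        #
+        #   session.execute(insert(User).values(name="x"))        -> "orm"
+        #   session.execute(insert(User), [{"name": "x"}, {...}]) -> "bulk"
+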
+        if insert_options._dml_strategy != "raw":
+            # for ORM object loading, like ORMContext, we have to disable
+            # result set adapt_to_context, because we will be generating a
+            # new statement with specific columns that's cached inside of
+            # an ORMFromStatementCompileState, which we will re-use for
+            # each result.
+            if not execution_options:
+                execution_options = context._orm_load_exec_options
+            else:
+                execution_options = execution_options.union(
+                    context._orm_load_exec_options
+                )
+
+        statement = statement._annotate(
+            {"dml_strategy": insert_options._dml_strategy}
+        )
+
         return (
             statement,
-            util.immutabledict(execution_options),
+            util.immutabledict(execution_options).union(
+                {"_sa_orm_insert_options": insert_options}
+            ),
         )
 
     @classmethod
-    def orm_setup_cursor_result(
+    def orm_execute_statement(
         cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        result,
-    ):
-        return result
+        session: Session,
+        statement: dml.Insert,
+        params: _CoreAnyExecuteParams,
+        execution_options: _ExecuteOptionsParameter,
+        bind_arguments: _BindArguments,
+        conn: Connection,
+    ) -> _result.Result:
+
+        insert_options = execution_options.get(
+            "_sa_orm_insert_options", cls.default_insert_options
+        )
+
+        if insert_options._dml_strategy not in (
+            "raw",
+            "bulk",
+            "orm",
+            "auto",
+        ):
+            raise sa_exc.ArgumentError(
+                "Valid strategies for ORM insert strategy "
+                "are 'raw', 'orm', 'bulk', 'auto"
+            )
+
+        result: _result.Result[Any]
+
+        if insert_options._dml_strategy == "raw":
+            result = conn.execute(
+                statement, params or {}, execution_options=execution_options
+            )
+            return result
+
+        if insert_options._dml_strategy == "bulk":
+            mapper = insert_options._subject_mapper
+
+            if (
+                statement._post_values_clause is not None
+                and mapper._multiple_persistence_tables
+            ):
+                raise sa_exc.InvalidRequestError(
+                    "bulk INSERT with a 'post values' clause "
+                    "(typically upsert) not supported for multi-table "
+                    f"mapper {mapper}"
+                )
+
+            assert mapper is not None
+            assert session._transaction is not None
+            result = _bulk_insert(
+                mapper,
+                cast(
+                    "Iterable[Dict[str, Any]]",
+                    [params] if isinstance(params, dict) else params,
+                ),
+                session._transaction,
+                isstates=False,
+                return_defaults=insert_options._return_defaults,
+                render_nulls=insert_options._render_nulls,
+                use_orm_insert_stmt=statement,
+                execution_options=execution_options,
+            )
+        elif insert_options._dml_strategy == "orm":
+            result = conn.execute(
+                statement, params or {}, execution_options=execution_options
+            )
+        else:
+            raise AssertionError()
+
+        if not bool(statement._returning):
+            return result
+
+        return cls._return_orm_returning(
+            session,
+            statement,
+            params,
+            execution_options,
+            bind_arguments,
+            result,
+        )
+
+    @classmethod
+    def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
+
+        self = cast(
+            BulkORMInsert,
+            super().create_for_statement(statement, compiler, **kw),
+        )
+
+        if compiler is not None:
+            toplevel = not compiler.stack
+        else:
+            toplevel = True
+        if not toplevel:
+            return self
+
+        mapper = statement._propagate_attrs["plugin_subject"]
+        dml_strategy = statement._annotations.get("dml_strategy", "raw")
+        if dml_strategy == "bulk":
+            self._setup_for_bulk_insert(compiler)
+        elif dml_strategy == "orm":
+            self._setup_for_orm_insert(compiler, mapper)
+
+        return self
+
+    @classmethod
+    def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
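+        # keys that name a mapped column are translated to that Column's
+        # .key; keys that don't resolve to a column pass through unchanged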
+        return {
+            col.key if col is not None else k: v
+            for col, k, v in (
+                (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
+            )
+        }
+
+    def _setup_for_orm_insert(self, compiler, mapper):
+        statement = orm_level_statement = cast(dml.Insert, self.statement)
+
+        statement = self._setup_orm_returning(
+            compiler,
+            orm_level_statement,
+            statement,
+            use_supplemental_cols=False,
+        )
+        self.statement = statement
+
+    def _setup_for_bulk_insert(self, compiler):
+        """establish an INSERT statement within the context of
+        bulk insert.
+
+        This method will be called within the "conn.execute()" call that is
+        invoked by persistence._emit_insert_statement().
+
+        """
+        statement = orm_level_statement = cast(dml.Insert, self.statement)
+        an = statement._annotations
+
+        emit_insert_table, emit_insert_mapper = (
+            an["_emit_insert_table"],
+            an["_emit_insert_mapper"],
+        )
+
+        statement = statement._clone()
+
+        statement.table = emit_insert_table
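+        # for a multi-table (e.g. joined inheritance) mapping, limit the
+        # parameters to columns of the table actually being INSERTed here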
+        if self._dict_parameters:
+            self._dict_parameters = {
+                col: val
+                for col, val in self._dict_parameters.items()
+                if col.table is emit_insert_table
+            }
+
+        statement = self._setup_orm_returning(
+            compiler,
+            orm_level_statement,
+            statement,
+            use_supplemental_cols=True,
+            dml_mapper=emit_insert_mapper,
+        )
+
+        self.statement = statement
 
 
 @CompileState.plugin_for("orm", "update")
@@ -732,13 +1322,27 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
 
         self = cls.__new__(cls)
 
+        dml_strategy = statement._annotations.get(
+            "dml_strategy", "unspecified"
+        )
+
+        if dml_strategy == "bulk":
+            self._setup_for_bulk_update(statement, compiler)
+        elif dml_strategy in ("orm", "unspecified"):
+            self._setup_for_orm_update(statement, compiler)
+
+        return self
+
+    def _setup_for_orm_update(self, statement, compiler, **kw):
+        orm_level_statement = statement
+
         ext_info = statement.table._annotations["parententity"]
 
         self.mapper = mapper = ext_info.mapper
 
         self.extra_criteria_entities = {}
 
-        self._resolved_values = cls._get_resolved_values(mapper, statement)
+        self._resolved_values = self._get_resolved_values(mapper, statement)
 
         extra_criteria_attributes = {}
 
@@ -749,8 +1353,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         if statement._values:
             self._resolved_values = dict(self._resolved_values)
 
-        new_stmt = sql.Update.__new__(sql.Update)
-        new_stmt.__dict__.update(statement.__dict__)
+        new_stmt = statement._clone()
         new_stmt.table = mapper.local_table
 
         # note if the statement has _multi_values, these
@@ -762,7 +1365,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         elif statement._values:
             new_stmt._values = self._resolved_values
 
-        new_crit = cls._adjust_for_extra_criteria(
+        new_crit = self._adjust_for_extra_criteria(
             extra_criteria_attributes, mapper
         )
         if new_crit:
@@ -776,21 +1379,150 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
 
         UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
 
-        if compiler._annotations.get(
+        use_supplemental_cols = False
+
+        synchronize_session = compiler._annotations.get(
             "synchronize_session", None
-        ) == "fetch" and self.can_use_returning(
-            compiler.dialect, mapper, is_multitable=self.is_multitable
-        ):
-            if new_stmt._returning:
-                raise sa_exc.InvalidRequestError(
-                    "Can't use synchronize_session='fetch' "
-                    "with explicit returning()"
+        )
+        can_use_returning = compiler._annotations.get(
+            "can_use_returning", None
+        )
+        if can_use_returning is not False:
+            # even though pre_exec has determined basic
+            # can_use_returning for the dialect, if we are to use
+            # RETURNING we need to run can_use_returning() at this level
+            # unconditionally because is_multitable was not known
+            # at the pre_exec level
+            can_use_returning = (
+                synchronize_session == "fetch"
+                and self.can_use_returning(
+                    compiler.dialect, mapper, is_multitable=self.is_multitable
                 )
-            self.statement = self.statement.returning(
-                *mapper.local_table.primary_key
             )
 
-        return self
+        if synchronize_session == "fetch" and can_use_returning:
+            use_supplemental_cols = True
+
+            # NOTE: we might want to RETURN the actual columns to be
+            # synchronized also.  however this is complicated and difficult
+            # to align against the behavior of "evaluate".  Additionally,
+            # in a large number (if not the majority) of cases, we have the
+            # "evaluate" answer, usually a fixed value, in memory already and
+            # there's no need to re-fetch the same value over and over again.
+            # so perhaps it could RETURN just the elements that were based
+            # on a SQL expression and not a constant.  For now it doesn't
+            # quite seem worth it.
+            new_stmt = new_stmt.return_defaults(
+                *(list(mapper.local_table.primary_key))
+            )
+
+        new_stmt = self._setup_orm_returning(
+            compiler,
+            orm_level_statement,
+            new_stmt,
+            use_supplemental_cols=use_supplemental_cols,
+        )
+
+        self.statement = new_stmt
+
+    def _setup_for_bulk_update(self, statement, compiler, **kw):
+        """establish an UPDATE statement within the context of
+        bulk update.
+
+        This method will be called within the "conn.execute()" call that is
+        invoked by persistence._emit_update_statement().
+
+        """
+        statement = cast(dml.Update, statement)
+        an = statement._annotations
+
+        emit_update_table, _ = (
+            an["_emit_update_table"],
+            an["_emit_update_mapper"],
+        )
+
+        statement = statement._clone()
+        statement.table = emit_update_table
+
+        UpdateDMLState.__init__(self, statement, compiler, **kw)
+
+        if self._ordered_values:
+            raise sa_exc.InvalidRequestError(
+                "bulk ORM UPDATE does not support ordered_values() for "
+                "custom UPDATE statements with bulk parameter sets.  Use a "
+                "non-bulk UPDATE statement or use values()."
+            )
+
+        if self._dict_parameters:
+            self._dict_parameters = {
+                col: val
+                for col, val in self._dict_parameters.items()
+                if col.table is emit_update_table
+            }
+        self.statement = statement
+
+    @classmethod
+    def orm_execute_statement(
+        cls,
+        session: Session,
+        statement: dml.Update,
+        params: _CoreAnyExecuteParams,
+        execution_options: _ExecuteOptionsParameter,
+        bind_arguments: _BindArguments,
+        conn: Connection,
+    ) -> _result.Result:
+
+        update_options = execution_options.get(
+            "_sa_orm_update_options", cls.default_update_options
+        )
+
+        if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+            raise sa_exc.ArgumentError(
+                "Valid strategies for ORM UPDATE strategy "
+                "are 'orm', 'auto', 'bulk'"
+            )
+
+        result: _result.Result[Any]
+
+        if update_options._dml_strategy == "bulk":
+            if statement._where_criteria:
+                raise sa_exc.InvalidRequestError(
+                    "WHERE clause with bulk ORM UPDATE not "
+                    "supported right now.   Statement may be invoked at the "
+                    "Core level using "
+                    "session.connection().execute(stmt, parameters)"
+                )
+            mapper = update_options._subject_mapper
+            assert mapper is not None
+            assert session._transaction is not None
+            result = _bulk_update(
+                mapper,
+                cast(
+                    "Iterable[Dict[str, Any]]",
+                    [params] if isinstance(params, dict) else params,
+                ),
+                session._transaction,
+                isstates=False,
+                update_changed_only=False,
+                use_orm_update_stmt=statement,
+            )
+            return cls.orm_setup_cursor_result(
+                session,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                result,
+            )
+        else:
+            return super().orm_execute_statement(
+                session,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                conn,
+            )
 
     @classmethod
     def can_use_returning(
@@ -827,119 +1559,80 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         return True
 
     @classmethod
-    def _get_crud_kv_pairs(cls, statement, kv_iterator):
-        plugin_subject = statement._propagate_attrs["plugin_subject"]
-
-        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
-
-        if not plugin_subject or not plugin_subject.mapper:
-            return core_get_crud_kv_pairs(statement, kv_iterator)
-
-        mapper = plugin_subject.mapper
-
-        values = []
-
-        for k, v in kv_iterator:
-            k = coercions.expect(roles.DMLColumnRole, k)
+    def _do_post_synchronize_bulk_evaluate(
+        cls, session, params, result, update_options
+    ):
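+        """apply values from a bulk UPDATE parameter set to matching
+        in-session objects, locating each object by primary key in the
+        identity map.
+
+        """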
+        if not params:
+            return
 
-            if isinstance(k, str):
-                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
-                if desc is NO_VALUE:
-                    values.append(
-                        (
-                            k,
-                            coercions.expect(
-                                roles.ExpressionElementRole,
-                                v,
-                                type_=sqltypes.NullType(),
-                                is_crud=True,
-                            ),
-                        )
-                    )
-                else:
-                    values.extend(
-                        core_get_crud_kv_pairs(
-                            statement, desc._bulk_update_tuples(v)
-                        )
-                    )
-            elif "entity_namespace" in k._annotations:
-                k_anno = k._annotations
-                attr = _entity_namespace_key(
-                    k_anno["entity_namespace"], k_anno["proxy_key"]
-                )
-                values.extend(
-                    core_get_crud_kv_pairs(
-                        statement, attr._bulk_update_tuples(v)
-                    )
-                )
-            else:
-                values.append(
-                    (
-                        k,
-                        coercions.expect(
-                            roles.ExpressionElementRole,
-                            v,
-                            type_=sqltypes.NullType(),
-                            is_crud=True,
-                        ),
-                    )
-                )
-        return values
+        mapper = update_options._subject_mapper
+        pk_keys = [prop.key for prop in mapper._identity_key_props]
 
-    @classmethod
-    def _do_post_synchronize_evaluate(cls, session, result, update_options):
+        identity_map = session.identity_map
 
-        states = set()
-        evaluated_keys = list(update_options._value_evaluators.keys())
-        values = update_options._resolved_keys_as_propnames
-        attrib = set(k for k, v in values)
-        for obj in update_options._matched_objects:
-
-            state, dict_ = (
-                attributes.instance_state(obj),
-                attributes.instance_dict(obj),
+        for param in params:
+            identity_key = mapper.identity_key_from_primary_key(
+                (param[key] for key in pk_keys),
+                update_options._refresh_identity_token,
             )
-
-            # the evaluated states were gathered across all identity tokens.
-            # however the post_sync events are called per identity token,
-            # so filter.
-            if (
-                update_options._refresh_identity_token is not None
-                and state.identity_token
-                != update_options._refresh_identity_token
-            ):
+            state = identity_map.fast_get_state(identity_key)
+            if not state:
                 continue
 
+            evaluated_keys = set(param).difference(pk_keys)
+
+            dict_ = state.dict
             # only evaluate unmodified attributes
             to_evaluate = state.unmodified.intersection(evaluated_keys)
             for key in to_evaluate:
                 if key in dict_:
-                    dict_[key] = update_options._value_evaluators[key](obj)
+                    dict_[key] = param[key]
 
             state.manager.dispatch.refresh(state, None, to_evaluate)
 
             state._commit(dict_, list(to_evaluate))
 
-            to_expire = attrib.intersection(dict_).difference(to_evaluate)
+            # attributes that were formerly modified instead get expired.
+            # this only gets hit if the session had pending changes
+            # and autoflush was set to False.
+            to_expire = evaluated_keys.intersection(dict_).difference(
+                to_evaluate
+            )
             if to_expire:
                 state._expire_attributes(dict_, to_expire)
 
-            states.add(state)
-        session._register_altered(states)
+    @classmethod
+    def _do_post_synchronize_evaluate(
+        cls, session, statement, result, update_options
+    ):
+
+        matched_objects = cls._get_matched_objects_on_criteria(
+            update_options,
+            session.identity_map.all_states(),
+        )
+
+        cls._apply_update_set_values_to_objects(
+            session,
+            update_options,
+            statement,
+            [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
+        )
 
     @classmethod
-    def _do_post_synchronize_fetch(cls, session, result, update_options):
+    def _do_post_synchronize_fetch(
+        cls, session, statement, result, update_options
+    ):
         target_mapper = update_options._subject_mapper
 
-        states = set()
-        evaluated_keys = list(update_options._value_evaluators.keys())
-
-        if result.returns_rows:
-            rows = cls._interpret_returning_rows(target_mapper, result.all())
+        returned_defaults_rows = result.returned_defaults_rows
+        if returned_defaults_rows:
+            pk_rows = cls._interpret_returning_rows(
+                target_mapper, returned_defaults_rows
+            )
 
             matched_rows = [
                 tuple(row) + (update_options._refresh_identity_token,)
-                for row in rows
+                for row in pk_rows
             ]
         else:
             matched_rows = update_options._matched_rows
@@ -960,23 +1653,69 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             if identity_key in session.identity_map
         ]
 
-        values = update_options._resolved_keys_as_propnames
-        attrib = set(k for k, v in values)
+        if not objs:
+            return
 
-        for obj in objs:
-            state, dict_ = (
-                attributes.instance_state(obj),
-                attributes.instance_dict(obj),
-            )
+        cls._apply_update_set_values_to_objects(
+            session,
+            update_options,
+            statement,
+            [
+                (
+                    obj,
+                    attributes.instance_state(obj),
+                    attributes.instance_dict(obj),
+                )
+                for obj in objs
+            ],
+        )
+
+    @classmethod
+    def _apply_update_set_values_to_objects(
+        cls, session, update_options, statement, matched_objects
+    ):
+        """apply values to objects derived from an update statement, e.g.
+        UPDATE..SET <values>
+
+        """
+        mapper = update_options._subject_mapper
+        target_cls = mapper.class_
+        evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+        resolved_values = cls._get_resolved_values(mapper, statement)
+        resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+            mapper, resolved_values
+        )
+        value_evaluators = {}
+        for key, value in resolved_keys_as_propnames:
+            try:
+                _evaluator = evaluator_compiler.process(
+                    coercions.expect(roles.ExpressionElementRole, value)
+                )
+            except evaluator.UnevaluatableError:
+                pass
+            else:
+                value_evaluators[key] = _evaluator
+
+        evaluated_keys = list(value_evaluators.keys())
+        attrib = set(k for k, v in resolved_keys_as_propnames)
+
+        states = set()
+        for obj, state, dict_ in matched_objects:
 
             to_evaluate = state.unmodified.intersection(evaluated_keys)
+
             for key in to_evaluate:
                 if key in dict_:
-                    dict_[key] = update_options._value_evaluators[key](obj)
+                    # only run eval for attributes that are present.
+                    dict_[key] = value_evaluators[key](obj)
+
             state.manager.dispatch.refresh(state, None, to_evaluate)
 
             state._commit(dict_, list(to_evaluate))
 
+            # attributes that were formerly modified instead get expired.
+            # this only gets hit if the session had pending changes
+            # and autoflush was set to False.
             to_expire = attrib.intersection(dict_).difference(to_evaluate)
             if to_expire:
                 state._expire_attributes(dict_, to_expire)
@@ -991,6 +1730,8 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
     def create_for_statement(cls, statement, compiler, **kw):
         self = cls.__new__(cls)
 
+        orm_level_statement = statement
+
         ext_info = statement.table._annotations["parententity"]
         self.mapper = mapper = ext_info.mapper
 
@@ -1002,30 +1743,96 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
             if opt._is_criteria_option:
                 opt.get_global_criteria(extra_criteria_attributes)
 
+        new_stmt = statement._clone()
+        new_stmt.table = mapper.local_table
+
         new_crit = cls._adjust_for_extra_criteria(
             extra_criteria_attributes, mapper
         )
         if new_crit:
-            statement = statement.where(*new_crit)
+            new_stmt = new_stmt.where(*new_crit)
 
         # do this first as we need to determine if there is
         # DELETE..FROM
-        DeleteDMLState.__init__(self, statement, compiler, **kw)
+        DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
 
-        if compiler._annotations.get(
+        use_supplemental_cols = False
+
+        synchronize_session = compiler._annotations.get(
             "synchronize_session", None
-        ) == "fetch" and self.can_use_returning(
-            compiler.dialect,
-            mapper,
-            is_multitable=self.is_multitable,
-            is_delete_using=compiler._annotations.get(
-                "is_delete_using", False
-            ),
-        ):
-            self.statement = statement.returning(*statement.table.primary_key)
+        )
+        can_use_returning = compiler._annotations.get(
+            "can_use_returning", None
+        )
+        if can_use_returning is not False:
+            # even though pre_exec has determined basic
+            # can_use_returning for the dialect, if we are to use
+            # RETURNING we need to run can_use_returning() at this level
+            # unconditionally because is_delete_using was not known
+            # at the pre_exec level
+            can_use_returning = (
+                synchronize_session == "fetch"
+                and self.can_use_returning(
+                    compiler.dialect,
+                    mapper,
+                    is_multitable=self.is_multitable,
+                    is_delete_using=compiler._annotations.get(
+                        "is_delete_using", False
+                    ),
+                )
+            )
+
+        if can_use_returning:
+            use_supplemental_cols = True
+
+            new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
+
+        new_stmt = self._setup_orm_returning(
+            compiler,
+            orm_level_statement,
+            new_stmt,
+            use_supplemental_cols=use_supplemental_cols,
+        )
+
+        self.statement = new_stmt
 
         return self
 
+    @classmethod
+    def orm_execute_statement(
+        cls,
+        session: Session,
+        statement: dml.Delete,
+        params: _CoreAnyExecuteParams,
+        execution_options: _ExecuteOptionsParameter,
+        bind_arguments: _BindArguments,
+        conn: Connection,
+    ) -> _result.Result:
+
+        update_options = execution_options.get(
+            "_sa_orm_update_options", cls.default_update_options
+        )
+
+        if update_options._dml_strategy == "bulk":
+            raise sa_exc.InvalidRequestError(
+                "Bulk ORM DELETE not supported right now. "
+                "Statement may be invoked at the "
+                "Core level using "
+                "session.connection().execute(stmt, parameters)"
+            )
+
+        if update_options._dml_strategy not in (
+            "orm",
+            "auto",
+        ):
+            raise sa_exc.ArgumentError(
+                "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+            )
+
+        return super().orm_execute_statement(
+            session, statement, params, execution_options, bind_arguments, conn
+        )
+
     @classmethod
     def can_use_returning(
         cls,
@@ -1068,25 +1875,41 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
         return True
 
     @classmethod
-    def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
-        session._remove_newly_deleted(
-            [
-                attributes.instance_state(obj)
-                for obj in update_options._matched_objects
-            ]
+    def _do_post_synchronize_evaluate(
+        cls, session, statement, result, update_options
+    ):
+        matched_objects = cls._get_matched_objects_on_criteria(
+            update_options,
+            session.identity_map.all_states(),
         )
 
+        to_delete = []
+
+        for _, state, dict_, is_partially_expired in matched_objects:
+            if is_partially_expired:
+                state._expire(dict_, session.identity_map._modified)
+            else:
+                to_delete.append(state)
+
+        if to_delete:
+            session._remove_newly_deleted(to_delete)
+
     @classmethod
-    def _do_post_synchronize_fetch(cls, session, result, update_options):
+    def _do_post_synchronize_fetch(
+        cls, session, statement, result, update_options
+    ):
         target_mapper = update_options._subject_mapper
 
-        if result.returns_rows:
-            rows = cls._interpret_returning_rows(target_mapper, result.all())
+        returned_defaults_rows = result.returned_defaults_rows
+
+        if returned_defaults_rows:
+            pk_rows = cls._interpret_returning_rows(
+                target_mapper, returned_defaults_rows
+            )
 
             matched_rows = [
                 tuple(row) + (update_options._refresh_identity_token,)
-                for row in rows
+                for row in pk_rows
             ]
         else:
             matched_rows = update_options._matched_rows
index dc96f8c3c076b4d1233e36f3885a9ca645d4a2eb..f8c7ba7143efdb8dd16166e8d911767fd505c9f9 100644 (file)
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
     from .query import Query
     from .session import _BindArguments
     from .session import Session
+    from ..engine import Result
     from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _ExecuteOptionsParameter
     from ..sql._typing import _ColumnsClauseArgument
@@ -203,15 +204,19 @@ _orm_load_exec_options = util.immutabledict(
 
 
 class AbstractORMCompileState(CompileState):
+    is_dml_returning = False
+
     @classmethod
     def create_for_statement(
         cls,
         statement: Union[Select, FromStatement],
         compiler: Optional[SQLCompiler],
         **kw: Any,
-    ) -> ORMCompileState:
+    ) -> AbstractORMCompileState:
         """Create a context for a statement given a :class:`.Compiler`.
+
         This method is always invoked in the context of SQLCompiler.process().
+
         For a Select object, this would be invoked from
         SQLCompiler.visit_select(). For the special FromStatement object used
         by Query to indicate "Query.from_statement()", this is called by
@@ -232,6 +237,28 @@ class AbstractORMCompileState(CompileState):
     ):
         raise NotImplementedError()
 
+    @classmethod
+    def orm_execute_statement(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        conn,
+    ) -> Result:
+        result = conn.execute(
+            statement, params or {}, execution_options=execution_options
+        )
+        return cls.orm_setup_cursor_result(
+            session,
+            statement,
+            params,
+            execution_options,
+            bind_arguments,
+            result,
+        )
+
     @classmethod
     def orm_setup_cursor_result(
         cls,
@@ -309,6 +336,17 @@ class ORMCompileState(AbstractORMCompileState):
     def __init__(self, *arg, **kw):
         raise NotImplementedError()
 
+    if TYPE_CHECKING:
+
+        @classmethod
+        def create_for_statement(
+            cls,
+            statement: Union[Select, FromStatement],
+            compiler: Optional[SQLCompiler],
+            **kw: Any,
+        ) -> ORMCompileState:
+            ...
+
     def _append_dedupe_col_collection(self, obj, col_collection):
         dedupe = self.dedupe_columns
         if obj not in dedupe:
@@ -332,26 +370,6 @@ class ORMCompileState(AbstractORMCompileState):
         else:
             return SelectState._column_naming_convention(label_style)
 
-    @classmethod
-    def create_for_statement(
-        cls,
-        statement: Union[Select, FromStatement],
-        compiler: Optional[SQLCompiler],
-        **kw: Any,
-    ) -> ORMCompileState:
-        """Create a context for a statement given a :class:`.Compiler`.
-
-        This method is always invoked in the context of SQLCompiler.process().
-
-        For a Select object, this would be invoked from
-        SQLCompiler.visit_select(). For the special FromStatement object used
-        by Query to indicate "Query.from_statement()", this is called by
-        FromStatement._compiler_dispatch() that would be called by
-        SQLCompiler.process().
-
-        """
-        raise NotImplementedError()
-
     @classmethod
     def get_column_descriptions(cls, statement):
         return _column_descriptions(statement)
@@ -518,6 +536,49 @@ class ORMCompileState(AbstractORMCompileState):
         )
 
 
+class DMLReturningColFilter:
+    """an adapter used for the DML RETURNING case.
+
+    Has a subset of the interface used by
+    :class:`.ORMAdapter` and is used for :class:`._QueryEntity`
+    instances to set up their columns as used in RETURNING for a
+    DML statement.
+
+    """
+
+    __slots__ = ("mapper", "columns", "__weakref__")
+
+    def __init__(self, target_mapper, immediate_dml_mapper):
+        if (
+            immediate_dml_mapper is not None
+            and target_mapper.local_table
+            is not immediate_dml_mapper.local_table
+        ):
+            # joined inh, or in theory other kinds of multi-table mappings
+            self.mapper = immediate_dml_mapper
+        else:
+            # single inh, normal mappings, etc.
+            self.mapper = target_mapper
+        self.columns = util.WeakPopulateDict(
+            self.adapt_check_present  # type: ignore
+        )
+
+    def __call__(self, col, as_filter):
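+        # return the expression unchanged if any column it contains maps
+        # to the target mapper, otherwise None so it's filtered out of
+        # the RETURNING collection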
+        for cc in sql_util._find_columns(col):
+            c2 = self.adapt_check_present(cc)
+            if c2 is not None:
+                return col
+        else:
+            return None
+
+    def adapt_check_present(self, col):
+        mapper = self.mapper
+        prop = mapper._columntoproperty.get(col, None)
+        if prop is None:
+            return None
+        return mapper.local_table.c.corresponding_column(col)
+
+
 @sql.base.CompileState.plugin_for("orm", "orm_from_statement")
 class ORMFromStatementCompileState(ORMCompileState):
     _from_obj_alias = None
@@ -525,7 +586,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
     statement_container: FromStatement
     requested_statement: Union[SelectBase, TextClause, UpdateBase]
-    dml_table: _DMLTableElement
+    dml_table: Optional[_DMLTableElement] = None
 
     _has_orm_entities = False
     multi_row_eager_loaders = False
@@ -541,7 +602,7 @@ class ORMFromStatementCompileState(ORMCompileState):
         statement_container: Union[Select, FromStatement],
         compiler: Optional[SQLCompiler],
         **kw: Any,
-    ) -> ORMCompileState:
+    ) -> ORMFromStatementCompileState:
 
         if compiler is not None:
             toplevel = not compiler.stack
@@ -565,6 +626,7 @@ class ORMFromStatementCompileState(ORMCompileState):
 
         if statement.is_dml:
             self.dml_table = statement.table
+            self.is_dml_returning = True
 
         self._entities = []
         self._polymorphic_adapters = {}
@@ -674,6 +736,18 @@ class ORMFromStatementCompileState(ORMCompileState):
     def _get_current_adapter(self):
         return None
 
+    def setup_dml_returning_compile_state(self, dml_mapper):
+        """used by BulkORMInsert (and Update / Delete?) to set up a handler
+        for RETURNING to return ORM objects and expressions
+
+        """
+        target_mapper = self.statement._propagate_attrs.get(
+            "plugin_subject", None
+        )
+        adapter = DMLReturningColFilter(target_mapper, dml_mapper)
+        for entity in self._entities:
+            entity.setup_dml_returning_compile_state(self, adapter)
+
 
 class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
     """Core construct that represents a load of ORM objects from various
@@ -813,7 +887,7 @@ class ORMSelectCompileState(ORMCompileState, SelectState):
         statement: Union[Select, FromStatement],
         compiler: Optional[SQLCompiler],
         **kw: Any,
-    ) -> ORMCompileState:
+    ) -> ORMSelectCompileState:
         """compiler hook, we arrive here from compiler.visit_select() only."""
 
         self = cls.__new__(cls)
@@ -2312,6 +2386,13 @@ class _QueryEntity:
     def setup_compile_state(self, compile_state: ORMCompileState) -> None:
         raise NotImplementedError()
 
+    def setup_dml_returning_compile_state(
+        self,
+        compile_state: ORMCompileState,
+        adapter: DMLReturningColFilter,
+    ) -> None:
+        raise NotImplementedError()
+
     def row_processor(self, context, result):
         raise NotImplementedError()
 
@@ -2509,8 +2590,24 @@ class _MapperEntity(_QueryEntity):
 
         return _instance, self._label_name, self._extra_entities
 
-    def setup_compile_state(self, compile_state):
+    def setup_dml_returning_compile_state(
+        self,
+        compile_state: ORMCompileState,
+        adapter: DMLReturningColFilter,
+    ) -> None:
+        loading._setup_entity_query(
+            compile_state,
+            self.mapper,
+            self,
+            self.path,
+            adapter,
+            compile_state.primary_columns,
+            with_polymorphic=self._with_polymorphic_mappers,
+            only_load_props=compile_state.compile_options._only_load_props,
+            polymorphic_discriminator=self._polymorphic_discriminator,
+        )
 
+    def setup_compile_state(self, compile_state):
         adapter = self._get_entity_clauses(compile_state)
 
         single_table_crit = self.mapper._single_table_criterion
@@ -2536,7 +2633,6 @@ class _MapperEntity(_QueryEntity):
             only_load_props=compile_state.compile_options._only_load_props,
             polymorphic_discriminator=self._polymorphic_discriminator,
         )
-
         compile_state._fallback_from_clauses.append(self.selectable)
 
 
@@ -2743,9 +2839,7 @@ class _ColumnEntity(_QueryEntity):
             getter, label_name, extra_entities = self._row_processor
             if self.translate_raw_column:
                 extra_entities += (
-                    result.context.invoked_statement._raw_columns[
-                        self.raw_column_index
-                    ],
+                    context.query._raw_columns[self.raw_column_index],
                 )
 
             return getter, label_name, extra_entities
@@ -2781,9 +2875,7 @@ class _ColumnEntity(_QueryEntity):
 
         if self.translate_raw_column:
             extra_entities = self._extra_entities + (
-                result.context.invoked_statement._raw_columns[
-                    self.raw_column_index
-                ],
+                context.query._raw_columns[self.raw_column_index],
             )
             return getter, self._label_name, extra_entities
         else:
@@ -2843,6 +2935,8 @@ class _RawColumnEntity(_ColumnEntity):
         current_adapter = compile_state._get_current_adapter()
         if current_adapter:
             column = current_adapter(self.column, False)
+            if column is None:
+                return
         else:
             column = self.column
 
@@ -2944,10 +3038,25 @@ class _ORMColumnEntity(_ColumnEntity):
                 self.entity_zero
             ) and entity.common_parent(self.entity_zero)
 
+    def setup_dml_returning_compile_state(
+        self,
+        compile_state: ORMCompileState,
+        adapter: DMLReturningColFilter,
+    ) -> None:
+        self._fetch_column = self.column
+        column = adapter(self.column, False)
+        if column is not None:
+            compile_state.dedupe_columns.add(column)
+            compile_state.primary_columns.append(column)
+
     def setup_compile_state(self, compile_state):
         current_adapter = compile_state._get_current_adapter()
         if current_adapter:
             column = current_adapter(self.column, False)
+            if column is None:
+                assert compile_state.is_dml_returning
+                self._fetch_column = self.column
+                return
         else:
             column = self.column
 
index 52b70b9d47a2aaded3060c896c8a518edf0fdf6b..13d3b70fe6ee89cf8e3f5f25370e7e9234abab80 100644 (file)
@@ -19,6 +19,7 @@ import operator
 import typing
 from typing import Any
 from typing import Callable
+from typing import Dict
 from typing import List
 from typing import NoReturn
 from typing import Optional
@@ -602,6 +603,31 @@ class Composite(
     def _attribute_keys(self) -> Sequence[str]:
         return [prop.key for prop in self.props]
 
+    def _populate_composite_bulk_save_mappings_fn(
+        self,
+    ) -> Callable[[Dict[str, Any]], None]:
+
+        if self._generated_composite_accessor:
+            get_values = self._generated_composite_accessor
+        else:
+
+            def get_values(val: Any) -> Tuple[Any, ...]:
+                return val.__composite_values__()  # type: ignore
+
+        attrs = [prop.key for prop in self.props]
+
+        def populate(dest_dict: Dict[str, Any]) -> None:
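+            # e.g. (illustrative) for a composite "point" composed of
+            # columns "x" and "y":
+            #   {"id": 1, "point": Point(3, 4)} -> {"id": 1, "x": 3, "y": 4}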
+            dest_dict.update(
+                {
+                    key: val
+                    for key, val in zip(
+                        attrs, get_values(dest_dict.pop(self.key))
+                    )
+                }
+            )
+
+        return populate
+
     def get_history(
         self,
         state: InstanceState[Any],
index b3129afdd7a41317955940e1e9d79f0886d1d527..5af14cc004909a1a1cedc3eaa7e840650270d262 100644 (file)
@@ -9,8 +9,8 @@
 
 from __future__ import annotations
 
-import operator
-
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
 from .. import exc
 from .. import inspect
 from .. import util
@@ -32,7 +32,16 @@ class _NoObject(operators.ColumnOperators):
         return None
 
 
+class _ExpiredObject(operators.ColumnOperators):
+    def operate(self, *arg, **kw):
+        return self
+
+    def reverse_operate(self, *arg, **kw):
+        return self
+
+
 _NO_OBJECT = _NoObject()
+_EXPIRED_OBJECT = _ExpiredObject()
 
 
 class EvaluatorCompiler:
@@ -73,6 +82,24 @@ class EvaluatorCompiler:
                     f"alternate class {parentmapper.class_}"
                 )
             key = parentmapper._columntoproperty[clause].key
+            impl = parentmapper.class_manager[key].impl
+
+            if impl is not None:
+
+                def get_corresponding_attr(obj):
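+                    # use the attribute impl with PASSIVE_NO_FETCH so an
+                    # expired / unloaded attribute comes back as
+                    # _EXPIRED_OBJECT rather than emitting a lazy load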
+                    if obj is None:
+                        return _NO_OBJECT
+                    state = inspect(obj)
+                    dict_ = state.dict
+
+                    value = impl.get(
+                        state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
+                    )
+                    if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
+                        return _EXPIRED_OBJECT
+                    return value
+
+                return get_corresponding_attr
         else:
             key = clause.key
             if (
@@ -85,15 +112,16 @@ class EvaluatorCompiler:
                     "make use of the actual mapped columns in ORM-evaluated "
                     "UPDATE / DELETE expressions."
                 )
+
             else:
                 raise UnevaluatableError(f"Cannot evaluate column: {clause}")
 
-        get_corresponding_attr = operator.attrgetter(key)
-        return (
-            lambda obj: get_corresponding_attr(obj)
-            if obj is not None
-            else _NO_OBJECT
-        )
+        def get_corresponding_attr(obj):
+            if obj is None:
+                return _NO_OBJECT
+            return getattr(obj, key, _EXPIRED_OBJECT)
+
+        return get_corresponding_attr
 
     def visit_tuple(self, clause):
         return self.visit_clauselist(clause)
@@ -134,7 +162,9 @@ class EvaluatorCompiler:
             has_null = False
             for sub_evaluate in evaluators:
                 value = sub_evaluate(obj)
-                if value:
+                if value is _EXPIRED_OBJECT:
+                    return _EXPIRED_OBJECT
+                elif value:
                     return True
                 has_null = has_null or value is None
             if has_null:
@@ -147,6 +177,9 @@ class EvaluatorCompiler:
         def evaluate(obj):
             for sub_evaluate in evaluators:
                 value = sub_evaluate(obj)
+                if value is _EXPIRED_OBJECT:
+                    return _EXPIRED_OBJECT
+
                 if not value:
                     if value is None or value is _NO_OBJECT:
                         return None
@@ -160,7 +193,9 @@ class EvaluatorCompiler:
             values = []
             for sub_evaluate in evaluators:
                 value = sub_evaluate(obj)
-                if value is None or value is _NO_OBJECT:
+                if value is _EXPIRED_OBJECT:
+                    return _EXPIRED_OBJECT
+                elif value is None or value is _NO_OBJECT:
                     return None
                 values.append(value)
             return tuple(values)
@@ -183,13 +218,21 @@ class EvaluatorCompiler:
 
     def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
         def evaluate(obj):
-            return eval_left(obj) == eval_right(obj)
+            left_val = eval_left(obj)
+            right_val = eval_right(obj)
+            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+                return _EXPIRED_OBJECT
+            return left_val == right_val
 
         return evaluate
 
     def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
         def evaluate(obj):
-            return eval_left(obj) != eval_right(obj)
+            left_val = eval_left(obj)
+            right_val = eval_right(obj)
+            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+                return _EXPIRED_OBJECT
+            return left_val != right_val
 
         return evaluate
 
@@ -197,8 +240,11 @@ class EvaluatorCompiler:
         def evaluate(obj):
             left_val = eval_left(obj)
             right_val = eval_right(obj)
-            if left_val is None or right_val is None:
+            if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+                return _EXPIRED_OBJECT
+            elif left_val is None or right_val is None:
                 return None
+
             return operator(eval_left(obj), eval_right(obj))
 
         return evaluate
@@ -274,7 +320,9 @@ class EvaluatorCompiler:
 
             def evaluate(obj):
                 value = eval_inner(obj)
-                if value is None:
+                if value is _EXPIRED_OBJECT:
+                    return _EXPIRED_OBJECT
+                elif value is None:
                     return None
                 return not value
 
index 63b131a780e796105e624c9987aeb70431016342..4848f73f1364819dbd992305162af4cce7dcbbed 100644 (file)
@@ -68,6 +68,11 @@ class IdentityMap:
     ) -> Optional[_O]:
         raise NotImplementedError()
 
+    def fast_get_state(
+        self, key: _IdentityKeyType[_O]
+    ) -> Optional[InstanceState[_O]]:
+        raise NotImplementedError()
+
     def keys(self) -> Iterable[_IdentityKeyType[Any]]:
         return self._dict.keys()
 
@@ -206,6 +211,11 @@ class WeakInstanceDict(IdentityMap):
         self._dict[key] = state
         state._instance_dict = self._wr
 
+    def fast_get_state(
+        self, key: _IdentityKeyType[_O]
+    ) -> Optional[InstanceState[_O]]:
+        return self._dict.get(key)
+
     def get(
         self, key: _IdentityKeyType[_O], default: Optional[_O] = None
     ) -> Optional[_O]:
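
A hedged sketch of how the new fast_get_state() accessor differs from IdentityMap.get(): it hands back the raw InstanceState for an identity key (or None) without resolving the weak reference to the mapped object. This assumes a SQLAlchemy 2.0 install where the accessor exists; the Node mapping is purely illustrative:

    from sqlalchemy import create_engine, inspect
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Node(Base):
        __tablename__ = "node"
        id: Mapped[int] = mapped_column(primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        node = Node(id=1)
        session.add(node)
        session.flush()  # node is now present in the identity map

        key = inspect(node).key  # identity key tuple for the persistent object
        state = session.identity_map.fast_get_state(key)
        assert state is inspect(node)  # the InstanceState itself, not the object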
index 7317d48be97de9a016cfeea01b3fa75f123ce70e..64f2542fda883a4fc11b7ed1b704ce9445410353 100644 (file)
@@ -29,7 +29,6 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
-from sqlalchemy.orm.context import FromStatement
 from . import attributes
 from . import exc as orm_exc
 from . import path_registry
@@ -37,6 +36,7 @@ from .base import _DEFER_FOR_STATE
 from .base import _RAISE_FOR_STATE
 from .base import _SET_DEFERRED_EXPIRED
 from .base import PassiveFlag
+from .context import FromStatement
 from .util import _none_set
 from .util import state_str
 from .. import exc as sa_exc
@@ -50,6 +50,7 @@ from ..sql import util as sql_util
 from ..sql.selectable import ForUpdateArg
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..sql.selectable import SelectState
+from ..util import EMPTY_DICT
 
 if TYPE_CHECKING:
     from ._typing import _IdentityKeyType
@@ -764,7 +765,7 @@ def _instance_processor(
             )
 
         quick_populators = path.get(
-            context.attributes, "memoized_setups", _none_set
+            context.attributes, "memoized_setups", EMPTY_DICT
         )
 
         todo = []
index c8df51b0689d38f4db08d94e807377249f4c9423..c9cf8f49bf487bd8de2e76b91a03b0b9d2629f5d 100644 (file)
@@ -854,6 +854,7 @@ class Mapper(
     _memoized_values: Dict[Any, Callable[[], Any]]
     _inheriting_mappers: util.WeakSequence[Mapper[Any]]
     _all_tables: Set[Table]
+    _polymorphic_attr_key: Optional[str]
 
     _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
     _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
@@ -1653,6 +1654,7 @@ class Mapper(
 
         """
         setter = False
+        polymorphic_key: Optional[str] = None
 
         if self.polymorphic_on is not None:
             setter = True
@@ -1772,17 +1774,23 @@ class Mapper(
                         self._set_polymorphic_identity = (
                             mapper._set_polymorphic_identity
                         )
+                        self._polymorphic_attr_key = (
+                            mapper._polymorphic_attr_key
+                        )
                         self._validate_polymorphic_identity = (
                             mapper._validate_polymorphic_identity
                         )
                     else:
                         self._set_polymorphic_identity = None
+                        self._polymorphic_attr_key = None
                     return
 
         if setter:
 
             def _set_polymorphic_identity(state):
                 dict_ = state.dict
+                # TODO: what happens if polymorphic_on column attribute name
+                # does not match .key?
                 state.get_impl(polymorphic_key).set(
                     state,
                     dict_,
@@ -1790,6 +1798,8 @@ class Mapper(
                     None,
                 )
 
+            self._polymorphic_attr_key = polymorphic_key
+
             def _validate_polymorphic_identity(mapper, state, dict_):
                 if (
                     polymorphic_key in dict_
@@ -1808,6 +1818,7 @@ class Mapper(
                 _validate_polymorphic_identity
             )
         else:
+            self._polymorphic_attr_key = None
             self._set_polymorphic_identity = None
 
     _validate_polymorphic_identity = None
@@ -3561,6 +3572,10 @@ class Mapper(
     def _compiled_cache(self):
         return util.LRUCache(self._compiled_cache_size)
 
+    @HasMemoized.memoized_attribute
+    def _multiple_persistence_tables(self):
+        return len(self.tables) > 1
+
     @HasMemoized.memoized_attribute
     def _sorted_tables(self):
         table_to_mapper: Dict[Table, Mapper[Any]] = {}
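
To illustrate what _polymorphic_attr_key enables for the bulk path (see the params.setdefault() call in the persistence changes further down), here is a hedged sketch assuming the behavior added by this commit: dictionaries bulk-inserted for a subclass get the discriminator filled in from the mapper's polymorphic identity. The Employee/Engineer mapping is illustrative only.

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Employee(Base):
        __tablename__ = "employee"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(50))
        type: Mapped[str] = mapped_column(String(20))
        __mapper_args__ = {
            "polymorphic_on": "type",
            "polymorphic_identity": "employee",
        }

    class Engineer(Employee):
        __mapper_args__ = {"polymorphic_identity": "engineer"}

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # no "type" key in the dictionaries; the discriminator is set to
        # "engineer" automatically from the mapper's polymorphic identity
        session.bulk_insert_mappings(
            Engineer, [{"name": "e1"}, {"name": "e2"}]
        )
        session.commit()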
index abd52898602b317665393c9b1513655e05da90b2..dfb61c28ac2e95659773a709dc6b6d2558845456 100644 (file)
@@ -31,6 +31,7 @@ from .. import exc as sa_exc
 from .. import future
 from .. import sql
 from .. import util
+from ..engine import cursor as _cursor
 from ..sql import operators
 from ..sql.elements import BooleanClauseList
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -398,6 +399,11 @@ def _collect_insert_commands(
                 None
             )
 
+        if bulk and mapper._set_polymorphic_identity:
+            params.setdefault(
+                mapper._polymorphic_attr_key, mapper.polymorphic_identity
+            )
+
         yield (
             state,
             state_dict,
@@ -411,7 +417,11 @@ def _collect_insert_commands(
 
 
 def _collect_update_commands(
-    uowtransaction, table, states_to_update, bulk=False
+    uowtransaction,
+    table,
+    states_to_update,
+    bulk=False,
+    use_orm_update_stmt=None,
 ):
     """Identify sets of values to use in UPDATE statements for a
     list of states.
@@ -437,7 +447,11 @@ def _collect_update_commands(
 
         pks = mapper._pks_by_table[table]
 
-        value_params = {}
+        if use_orm_update_stmt is not None:
+            # TODO: ordered values, etc
+            value_params = use_orm_update_stmt._values
+        else:
+            value_params = {}
 
         propkey_to_col = mapper._propkey_to_col[table]
 
@@ -697,6 +711,7 @@ def _emit_update_statements(
     table,
     update,
     bookkeeping=True,
+    use_orm_update_stmt=None,
 ):
     """Emit UPDATE statements corresponding to value lists collected
     by _collect_update_commands()."""
@@ -708,7 +723,7 @@ def _emit_update_statements(
 
     execution_options = {"compiled_cache": base_mapper._compiled_cache}
 
-    def update_stmt():
+    def update_stmt(existing_stmt=None):
         clauses = BooleanClauseList._construct_raw(operators.and_)
 
         for col in mapper._pks_by_table[table]:
@@ -725,10 +740,17 @@ def _emit_update_statements(
                 )
             )
 
-        stmt = table.update().where(clauses)
+        if existing_stmt is not None:
+            stmt = existing_stmt.where(clauses)
+        else:
+            stmt = table.update().where(clauses)
         return stmt
 
-    cached_stmt = base_mapper._memo(("update", table), update_stmt)
+    if use_orm_update_stmt is not None:
+        cached_stmt = update_stmt(use_orm_update_stmt)
+
+    else:
+        cached_stmt = base_mapper._memo(("update", table), update_stmt)
 
     for (
         (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
@@ -747,6 +769,15 @@ def _emit_update_statements(
         records = list(records)
 
         statement = cached_stmt
+
+        if use_orm_update_stmt is not None:
+            statement = statement._annotate(
+                {
+                    "_emit_update_table": table,
+                    "_emit_update_mapper": mapper,
+                }
+            )
+
         return_defaults = False
 
         if not has_all_pks:
@@ -904,16 +935,35 @@ def _emit_insert_statements(
     table,
     insert,
     bookkeeping=True,
+    use_orm_insert_stmt=None,
+    execution_options=None,
 ):
     """Emit INSERT statements corresponding to value lists collected
     by _collect_insert_commands()."""
 
-    cached_stmt = base_mapper._memo(("insert", table), table.insert)
+    if use_orm_insert_stmt is not None:
+        cached_stmt = use_orm_insert_stmt
+        exec_opt = util.EMPTY_DICT
 
-    execution_options = {"compiled_cache": base_mapper._compiled_cache}
+        # if a user query with RETURNING was passed, we definitely need
+        # to use RETURNING.
+        returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
+    else:
+        returning_is_required_anyway = False
+        cached_stmt = base_mapper._memo(("insert", table), table.insert)
+        exec_opt = {"compiled_cache": base_mapper._compiled_cache}
+
+    if execution_options:
+        execution_options = util.EMPTY_DICT.merge_with(
+            exec_opt, execution_options
+        )
+    else:
+        execution_options = exec_opt
+
+    return_result = None
 
     for (
-        (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+        (connection, _, hasvalue, has_all_pks, has_all_defaults),
         records,
     ) in groupby(
         insert,
@@ -928,17 +978,29 @@ def _emit_insert_statements(
 
         statement = cached_stmt
 
+        if use_orm_insert_stmt is not None:
+            statement = statement._annotate(
+                {
+                    "_emit_insert_table": table,
+                    "_emit_insert_mapper": mapper,
+                }
+            )
+
         if (
-            not bookkeeping
-            or (
-                has_all_defaults
-                or not base_mapper.eager_defaults
-                or not base_mapper.local_table.implicit_returning
-                or not connection.dialect.insert_returning
+            (
+                not bookkeeping
+                or (
+                    has_all_defaults
+                    or not base_mapper.eager_defaults
+                    or not base_mapper.local_table.implicit_returning
+                    or not connection.dialect.insert_returning
+                )
             )
+            and not returning_is_required_anyway
             and has_all_pks
             and not hasvalue
         ):
+
             # the "we don't need newly generated values back" section.
             # here we have all the PKs, all the defaults or we don't want
             # to fetch them, or the dialect doesn't support RETURNING at all
@@ -946,7 +1008,7 @@ def _emit_insert_statements(
             records = list(records)
             multiparams = [rec[2] for rec in records]
 
-            c = connection.execute(
+            result = connection.execute(
                 statement, multiparams, execution_options=execution_options
             )
             if bookkeeping:
@@ -962,7 +1024,7 @@ def _emit_insert_statements(
                         has_all_defaults,
                     ),
                     last_inserted_params,
-                ) in zip(records, c.context.compiled_parameters):
+                ) in zip(records, result.context.compiled_parameters):
                     if state:
                         _postfetch(
                             mapper_rec,
@@ -970,19 +1032,20 @@ def _emit_insert_statements(
                             table,
                             state,
                             state_dict,
-                            c,
+                            result,
                             last_inserted_params,
                             value_params,
                             False,
-                            c.returned_defaults
-                            if not c.context.executemany
+                            result.returned_defaults
+                            if not result.context.executemany
                             else None,
                         )
                     else:
                         _postfetch_bulk_save(mapper_rec, state_dict, table)
 
         else:
-            # here, we need defaults and/or pk values back.
+            # here, we need defaults and/or pk values back or we otherwise
+            # know that we are using RETURNING in any case
 
             records = list(records)
             if (
@@ -991,6 +1054,16 @@ def _emit_insert_statements(
                 and len(records) > 1
             ):
                 do_executemany = True
+            elif returning_is_required_anyway:
+                if connection.dialect.insert_executemany_returning:
+                    do_executemany = True
+                else:
+                    raise sa_exc.InvalidRequestError(
+                        f"Can't use explicit RETURNING for bulk INSERT "
+                        f"operation with "
+                        f"{connection.dialect.dialect_description} backend; "
+                        f"executemany is not supported with RETURNING"
+                    )
             else:
                 do_executemany = False
 
@@ -998,6 +1071,7 @@ def _emit_insert_statements(
                 statement = statement.return_defaults(
                     *mapper._server_default_cols[table]
                 )
+
             if mapper.version_id_col is not None:
                 statement = statement.return_defaults(mapper.version_id_col)
             elif do_executemany:
@@ -1006,10 +1080,16 @@ def _emit_insert_statements(
             if do_executemany:
                 multiparams = [rec[2] for rec in records]
 
-                c = connection.execute(
+                result = connection.execute(
                     statement, multiparams, execution_options=execution_options
                 )
 
+                if use_orm_insert_stmt is not None:
+                    if return_result is None:
+                        return_result = result
+                    else:
+                        return_result = return_result.splice_vertically(result)
+
                 if bookkeeping:
                     for (
                         (
@@ -1027,9 +1107,9 @@ def _emit_insert_statements(
                         returned_defaults,
                     ) in zip_longest(
                         records,
-                        c.context.compiled_parameters,
-                        c.inserted_primary_key_rows,
-                        c.returned_defaults_rows or (),
+                        result.context.compiled_parameters,
+                        result.inserted_primary_key_rows,
+                        result.returned_defaults_rows or (),
                     ):
                         if inserted_primary_key is None:
                             # this is a real problem and means that we didn't
@@ -1062,7 +1142,7 @@ def _emit_insert_statements(
                                 table,
                                 state,
                                 state_dict,
-                                c,
+                                result,
                                 last_inserted_params,
                                 value_params,
                                 False,
@@ -1071,6 +1151,8 @@ def _emit_insert_statements(
                         else:
                             _postfetch_bulk_save(mapper_rec, state_dict, table)
             else:
+                assert not returning_is_required_anyway
+
                 for (
                     state,
                     state_dict,
@@ -1132,6 +1214,12 @@ def _emit_insert_statements(
                         else:
                             _postfetch_bulk_save(mapper_rec, state_dict, table)
 
+    if use_orm_insert_stmt is not None:
+        if return_result is None:
+            return _cursor.null_dml_result()
+        else:
+            return return_result
+
 
 def _emit_post_update_statements(
     base_mapper, uowtransaction, mapper, table, update
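
The use_orm_insert_stmt / returning_is_required_anyway plumbing above serves the new ORM bulk INSERT path. A hedged end-to-end sketch follows, assuming a backend and driver that support executemany with RETURNING (for example SQLite 3.35+ via insertmanyvalues); the User mapping is illustrative:

    from sqlalchemy import String, create_engine, insert
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class User(Base):
        __tablename__ = "user_account"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(30))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # keys are interpreted as ORM attribute names; parameter sets may be
        # heterogeneous, and the spliced result yields ORM objects
        result = session.execute(
            insert(User).returning(User),
            [{"name": "spongebob"}, {"name": "sandy"}],
        )
        users = result.scalars().all()
        session.commit()

On a backend where insert_executemany_returning is not available, the new code path raises InvalidRequestError (see the error message added above) rather than silently dropping the RETURNING clause.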
index 6d0f055e4bbb344c33eb27e7120d4566dccd73e4..4d5a98fcfded8b3360d8d471d01b23fa332b4be0 100644 (file)
@@ -2978,7 +2978,7 @@ class Query(
         )
 
     def delete(
-        self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
+        self, synchronize_session: _SynchronizeSessionArgument = "auto"
     ) -> int:
         r"""Perform a DELETE with an arbitrary WHERE clause.
 
@@ -3042,7 +3042,7 @@ class Query(
     def update(
         self,
         values: Dict[_DMLColumnArgument, Any],
-        synchronize_session: _SynchronizeSessionArgument = "evaluate",
+        synchronize_session: _SynchronizeSessionArgument = "auto",
         update_args: Optional[Dict[Any, Any]] = None,
     ) -> int:
         r"""Perform an UPDATE with an arbitrary WHERE clause.
index a690da0d5e69fcc904d7d2ce601e17a8a9e63ffa..64c01330614f30bfa02fdfb8ed511f16a702f9ec 100644 (file)
@@ -1828,12 +1828,13 @@ class Session(_SessionClassMethods, EventTarget):
             statement._propagate_attrs.get("compile_state_plugin", None)
             == "orm"
         ):
-            # note that even without "future" mode, we need
             compile_state_cls = CompileState._get_plugin_class_for_plugin(
                 statement, "orm"
             )
             if TYPE_CHECKING:
-                assert isinstance(compile_state_cls, ORMCompileState)
+                assert isinstance(
+                    compile_state_cls, context.AbstractORMCompileState
+                )
         else:
             compile_state_cls = None
 
@@ -1897,18 +1898,18 @@ class Session(_SessionClassMethods, EventTarget):
                 statement, params or {}, execution_options=execution_options
             )
 
-        result: Result[Any] = conn.execute(
-            statement, params or {}, execution_options=execution_options
-        )
-
         if compile_state_cls:
-            result = compile_state_cls.orm_setup_cursor_result(
+            result: Result[Any] = compile_state_cls.orm_execute_statement(
                 self,
                 statement,
-                params,
+                params or {},
                 execution_options,
                 bind_arguments,
-                result,
+                conn,
+            )
+        else:
+            result = conn.execute(
+                statement, params or {}, execution_options=execution_options
             )
 
         if _scalar_result:
@@ -2066,7 +2067,7 @@ class Session(_SessionClassMethods, EventTarget):
     def scalars(
         self,
         statement: TypedReturnsRows[Tuple[_T]],
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -2078,7 +2079,7 @@ class Session(_SessionClassMethods, EventTarget):
     def scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -2089,7 +2090,7 @@ class Session(_SessionClassMethods, EventTarget):
     def scalars(
         self,
         statement: Executable,
-        params: Optional[_CoreSingleExecuteParams] = None,
+        params: Optional[_CoreAnyExecuteParams] = None,
         *,
         execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
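
Since Session.scalars() now accepts _CoreAnyExecuteParams, a list of parameter dictionaries can be passed to it directly; a hedged sketch under the same assumptions as above (RETURNING-capable backend, illustrative Tag mapping):

    from sqlalchemy import String, create_engine, insert
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Tag(Base):
        __tablename__ = "tag"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(30))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        tags = session.scalars(
            insert(Tag).returning(Tag), [{"name": "red"}, {"name": "blue"}]
        ).all()
        assert {t.name for t in tags} == {"red", "blue"}
        session.commit()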
index 19c6493db43c8b1321d41fc7dd4de29c71082052..8652591c88089b19c0774c319f9125f545934a70 100644 (file)
@@ -227,6 +227,11 @@ class ColumnLoader(LoaderStrategy):
         fetch = self.columns[0]
         if adapter:
             fetch = adapter.columns[fetch]
+            if fetch is None:
+                # None happens here only for dml bulk_persistence cases
+                # when context.DMLReturningColFilter is used
+                return
+
         memoized_populators[self.parent_property] = fetch
 
     def init_class_attribute(self, mapper):
@@ -318,6 +323,12 @@ class ExpressionColumnLoader(ColumnLoader):
         fetch = columns[0]
         if adapter:
             fetch = adapter.columns[fetch]
+            if fetch is None:
+                # None is not expected to be the result of any adapter
+                # implementation here; however, there may be theoretical
+                # usages of returning() with context.DMLReturningColFilter
+                return
+
         memoized_populators[self.parent_property] = fetch
 
     def create_row_processor(
index 86b2952cb83e3dab5f827e4ad88c340e257352a7..262048bd1d5ccf623247a8f6a5ef0e575d10e21c 100644 (file)
@@ -552,6 +552,8 @@ def _new_annotation_type(
     # e.g. BindParameter, add it if present.
     if cls.__dict__.get("inherit_cache", False):
         anno_cls.inherit_cache = True  # type: ignore
+    elif "inherit_cache" in cls.__dict__:
+        anno_cls.inherit_cache = cls.__dict__["inherit_cache"]  # type: ignore
 
     anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
 
index 201324a2a13af8726bf27f84bb1ba63fd9c7606e..c7e226fcc639de220c0a067c708c50cc03d1dbaa 100644 (file)
@@ -5166,6 +5166,8 @@ class SQLCompiler(Compiled):
             delete_stmt, delete_stmt.table, extra_froms
         )
 
+        crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
+
         if delete_stmt._hints:
             dialect_hints, table_text = self._setup_crud_hints(
                 delete_stmt, table_text
@@ -5178,13 +5180,14 @@ class SQLCompiler(Compiled):
 
         text += table_text
 
-        if delete_stmt._returning:
-            if self.returning_precedes_values:
-                text += " " + self.returning_clause(
-                    delete_stmt,
-                    delete_stmt._returning,
-                    populate_result_map=toplevel,
-                )
+        if (
+            self.implicit_returning or delete_stmt._returning
+        ) and self.returning_precedes_values:
+            text += " " + self.returning_clause(
+                delete_stmt,
+                self.implicit_returning or delete_stmt._returning,
+                populate_result_map=toplevel,
+            )
 
         if extra_froms:
             extra_from_text = self.delete_extra_from_clause(
@@ -5204,10 +5207,12 @@ class SQLCompiler(Compiled):
             if t:
                 text += " WHERE " + t
 
-        if delete_stmt._returning and not self.returning_precedes_values:
+        if (
+            self.implicit_returning or delete_stmt._returning
+        ) and not self.returning_precedes_values:
             text += " " + self.returning_clause(
                 delete_stmt,
-                delete_stmt._returning,
+                self.implicit_returning or delete_stmt._returning,
                 populate_result_map=toplevel,
             )
 
@@ -5297,7 +5302,6 @@ class StrSQLCompiler(SQLCompiler):
             self._label_select_column(None, c, True, False, {})
             for c in base._select_iterables(returning_cols)
         ]
-
         return "RETURNING " + ", ".join(columns)
 
     def update_from_clause(
index b13377a590a69f226e7f321d037cb6e065743862..22fffb73a1a2e210499cf7cdfbb2ecc4765a2a84 100644 (file)
@@ -150,6 +150,22 @@ def _get_crud_params(
             "return_defaults() simultaneously"
         )
 
+    if compile_state.isdelete:
+        _setup_delete_return_defaults(
+            compiler,
+            stmt,
+            compile_state,
+            (),
+            _getattr_col_key,
+            _column_as_key,
+            _col_bind_name,
+            (),
+            (),
+            toplevel,
+            kw,
+        )
+        return _CrudParams([], [])
+
     # no parameters in the statement, no parameters in the
     # compiled params - return binds for all columns
     if compiler.column_keys is None and compile_state._no_parameters:
@@ -466,13 +482,6 @@ def _scan_insert_from_select_cols(
     kw,
 ):
 
-    (
-        need_pks,
-        implicit_returning,
-        implicit_return_defaults,
-        postfetch_lastrowid,
-    ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
-
     cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
 
     assert compiler.stack[-1]["selectable"] is stmt
@@ -537,6 +546,8 @@ def _scan_cols(
         postfetch_lastrowid,
     ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
 
+    assert compile_state.isupdate or compile_state.isinsert
+
     if compile_state._parameter_ordering:
         parameter_ordering = [
             _column_as_key(key) for key in compile_state._parameter_ordering
@@ -563,6 +574,13 @@ def _scan_cols(
     else:
         autoincrement_col = insert_null_pk_still_autoincrements = None
 
+    if stmt._supplemental_returning:
+        supplemental_returning = set(stmt._supplemental_returning)
+    else:
+        supplemental_returning = set()
+
+    compiler_implicit_returning = compiler.implicit_returning
+
     for c in cols:
         # scan through every column in the target table
 
@@ -627,11 +645,13 @@ def _scan_cols(
                 # column has a DDL-level default, and is either not a pk
                 # column or we don't need the pk.
                 if implicit_return_defaults and c in implicit_return_defaults:
-                    compiler.implicit_returning.append(c)
+                    compiler_implicit_returning.append(c)
                 elif not c.primary_key:
                     compiler.postfetch.append(c)
+
             elif implicit_return_defaults and c in implicit_return_defaults:
-                compiler.implicit_returning.append(c)
+                compiler_implicit_returning.append(c)
+
             elif (
                 c.primary_key
                 and c is not stmt.table._autoincrement_column
@@ -652,6 +672,59 @@ def _scan_cols(
                 kw,
             )
 
+        # adding supplemental cols to implicit_returning in table
+        # order so that order is maintained between multiple INSERT
+        # statements which may have different parameters included, but all
+        # have the same RETURNING clause
+        if (
+            c in supplemental_returning
+            and c not in compiler_implicit_returning
+        ):
+            compiler_implicit_returning.append(c)
+
+    if supplemental_returning:
+        # we should have gotten every col into implicit_returning;
+        # however, supplemental returning can also include SQL functions
+        # etc. in it
+        remaining_supplemental = supplemental_returning.difference(
+            compiler_implicit_returning
+        )
+        compiler_implicit_returning.extend(
+            c
+            for c in stmt._supplemental_returning
+            if c in remaining_supplemental
+        )
+
+
+def _setup_delete_return_defaults(
+    compiler,
+    stmt,
+    compile_state,
+    parameters,
+    _getattr_col_key,
+    _column_as_key,
+    _col_bind_name,
+    check_columns,
+    values,
+    toplevel,
+    kw,
+):
+    (_, _, implicit_return_defaults, _) = _get_returning_modifiers(
+        compiler, stmt, compile_state, toplevel
+    )
+
+    if not implicit_return_defaults:
+        return
+
+    if stmt._return_defaults_columns:
+        compiler.implicit_returning.extend(implicit_return_defaults)
+
+    if stmt._supplemental_returning:
+        ir_set = set(compiler.implicit_returning)
+        compiler.implicit_returning.extend(
+            c for c in stmt._supplemental_returning if c not in ir_set
+        )
+
 
 def _append_param_parameter(
     compiler,
@@ -743,7 +816,7 @@ def _append_param_parameter(
                 elif compiler.dialect.postfetch_lastrowid:
                     compiler.postfetch_lastrowid = True
 
-            elif implicit_return_defaults and c in implicit_return_defaults:
+            elif implicit_return_defaults and (c in implicit_return_defaults):
                 compiler.implicit_returning.append(c)
 
             else:
@@ -1303,6 +1376,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
     INSERT or UPDATE statement after it's invoked.
 
     """
+
     need_pks = (
         toplevel
         and _compile_state_isinsert(compile_state)
@@ -1315,6 +1389,7 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
             )
         )
         and not stmt._returning
+        # and (not stmt._returning or stmt._return_defaults)
         and not compile_state._has_multi_parameters
     )
 
@@ -1357,33 +1432,41 @@ def _get_returning_modifiers(compiler, stmt, compile_state, toplevel):
             or stmt._return_defaults
         )
     )
-
     if implicit_returning:
         postfetch_lastrowid = False
 
     if _compile_state_isinsert(compile_state):
-        implicit_return_defaults = implicit_returning and stmt._return_defaults
+        should_implicit_return_defaults = (
+            implicit_returning and stmt._return_defaults
+        )
     elif compile_state.isupdate:
-        implicit_return_defaults = (
+        should_implicit_return_defaults = (
             stmt._return_defaults
             and compile_state._primary_table.implicit_returning
             and compile_state._supports_implicit_returning
             and compiler.dialect.update_returning
         )
+    elif compile_state.isdelete:
+        should_implicit_return_defaults = (
+            stmt._return_defaults
+            and compile_state._primary_table.implicit_returning
+            and compile_state._supports_implicit_returning
+            and compiler.dialect.delete_returning
+        )
     else:
-        # this line is unused, currently we are always
-        # isinsert or isupdate
-        implicit_return_defaults = False  # pragma: no cover
+        should_implicit_return_defaults = False  # pragma: no cover
 
-    if implicit_return_defaults:
+    if should_implicit_return_defaults:
         if not stmt._return_defaults_columns:
             implicit_return_defaults = set(stmt.table.c)
         else:
             implicit_return_defaults = set(stmt._return_defaults_columns)
+    else:
+        implicit_return_defaults = None
 
     return (
         need_pks,
-        implicit_returning,
+        implicit_returning or should_implicit_return_defaults,
         implicit_return_defaults,
         postfetch_lastrowid,
     )
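
With _setup_delete_return_defaults() and the widened _get_returning_modifiers() above, return_defaults() can now drive an implicit RETURNING clause on DELETE. A hedged, compile-only sketch (the table and dialect choice are illustrative; the expected SQL shown is approximate):

    from sqlalchemy import Column, Integer, MetaData, String, Table, delete
    from sqlalchemy.dialects import postgresql

    metadata = MetaData()
    widgets = Table(
        "widgets",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
    )

    stmt = delete(widgets).where(widgets.c.data == "stale").return_defaults(
        widgets.c.id
    )
    print(stmt.compile(dialect=postgresql.dialect()))
    # roughly: DELETE FROM widgets WHERE widgets.data = %(data_1)s
    #          RETURNING widgets.id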
index a08e3880064b46c44306d5d4a79c310e3b543fd9..5145a4a16af43af528e5421c507156106699a4ad 100644 (file)
@@ -164,16 +164,33 @@ class DMLState(CompileState):
         def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
             ...
 
+    @classmethod
+    def _get_multi_crud_kv_pairs(
+        cls,
+        statement: UpdateBase,
+        multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]],
+    ) -> List[Dict[_DMLColumnElement, Any]]:
+        return [
+            {
+                coercions.expect(roles.DMLColumnRole, k): v
+                for k, v in mapping.items()
+            }
+            for mapping in multi_kv_iterator
+        ]
+
     @classmethod
     def _get_crud_kv_pairs(
         cls,
         statement: UpdateBase,
         kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
+        needs_to_be_cacheable: bool,
     ) -> List[Tuple[_DMLColumnElement, Any]]:
         return [
             (
                 coercions.expect(roles.DMLColumnRole, k),
-                coercions.expect(
+                v
+                if not needs_to_be_cacheable
+                else coercions.expect(
                     roles.ExpressionElementRole,
                     v,
                     type_=NullType(),
@@ -269,7 +286,7 @@ class InsertDMLState(DMLState):
     def _insert_col_keys(self) -> List[str]:
         # this is also done in crud.py -> _key_getters_for_crud_column
         return [
-            coercions.expect_as_key(roles.DMLColumnRole, col)
+            coercions.expect(roles.DMLColumnRole, col, as_key=True)
             for col in self._dict_parameters or ()
         ]
 
@@ -326,7 +343,6 @@ class UpdateDMLState(DMLState):
         self._extra_froms = ef
 
         self.is_multitable = mt = ef
-
         self.include_table_with_column_exprs = bool(
             mt and compiler.render_table_with_column_in_update_from
         )
@@ -389,6 +405,7 @@ class UpdateBase(
     _return_defaults_columns: Optional[
         Tuple[_ColumnsClauseElement, ...]
     ] = None
+    _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None
     _returning: Tuple[_ColumnsClauseElement, ...] = ()
 
     is_dml = True
@@ -434,6 +451,215 @@ class UpdateBase(
         self._validate_dialect_kwargs(opt)
         return self
 
+    @_generative
+    def return_defaults(
+        self: SelfUpdateBase,
+        *cols: _DMLColumnArgument,
+        supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None,
+    ) -> SelfUpdateBase:
+        """Make use of a :term:`RETURNING` clause for the purpose
+        of fetching server-side expressions and defaults, for supporting
+        backends only.
+
+        .. deepalchemy::
+
+            The :meth:`.UpdateBase.return_defaults` method is used by the ORM
+            for its internal work in fetching newly generated primary key
+            and server default values, in particular to provide the underlying
+            implementation of the :paramref:`_orm.Mapper.eager_defaults`
+            ORM feature as well as to allow RETURNING support with bulk
+            ORM inserts.  Its behavior is fairly idiosyncratic
+            and is not really intended for general use.  End users should
+            stick with using :meth:`.UpdateBase.returning` in order to
+            add RETURNING clauses to their INSERT, UPDATE and DELETE
+            statements.
+
+        Normally, a single row INSERT statement will automatically populate the
+        :attr:`.CursorResult.inserted_primary_key` attribute when executed,
+        which stores the primary key of the row that was just inserted in the
+        form of a :class:`.Row` object with column names as named tuple keys
+        (and the :attr:`.Row._mapping` view fully populated as well). The
+        dialect in use chooses the strategy to use in order to populate this
+        data; if it was generated using server-side defaults and / or SQL
+        expressions, dialect-specific approaches such as ``cursor.lastrowid``
+        or ``RETURNING`` are typically used to acquire the new primary key
+        value.
+
+        However, when the statement is modified by calling
+        :meth:`.UpdateBase.return_defaults` before executing the statement,
+        additional behaviors take place **only** for backends that support
+        RETURNING and for :class:`.Table` objects that maintain the
+        :paramref:`.Table.implicit_returning` parameter at its default value of
+        ``True``. In these cases, when the :class:`.CursorResult` is returned
+        from the statement's execution, not only will
+        :attr:`.CursorResult.inserted_primary_key` be populated as always, the
+        :attr:`.CursorResult.returned_defaults` attribute will also be
+        populated with a :class:`.Row` named-tuple representing the full range
+        of server generated
+        values from that single row, including values for any columns that
+        specify :paramref:`_schema.Column.server_default` or which make use of
+        :paramref:`_schema.Column.default` using a SQL expression.
+
+        When invoking INSERT statements with multiple rows using
+        :ref:`insertmanyvalues <engine_insertmanyvalues>`, the
+        :meth:`.UpdateBase.return_defaults` modifier will have the effect of
+        the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
+        :attr:`_engine.CursorResult.returned_defaults_rows` attributes being
+        fully populated with lists of :class:`.Row` objects representing newly
+        inserted primary key values as well as newly inserted server generated
+        values for each row inserted. The
+        :attr:`.CursorResult.inserted_primary_key` and
+        :attr:`.CursorResult.returned_defaults` attributes will also continue
+        to be populated with the first row of these two collections.
+
+        If the backend does not support RETURNING or the :class:`.Table` in use
+        has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
+        clause is added and no additional data is fetched; however, the
+        INSERT, UPDATE or DELETE statement proceeds normally.
+
+        E.g.::
+
+            stmt = table.insert().values(data='newdata').return_defaults()
+
+            result = connection.execute(stmt)
+
+            server_created_at = result.returned_defaults['created_at']
+
+        When used against an UPDATE statement
+        :meth:`.UpdateBase.return_defaults` instead looks for columns that
+        include :paramref:`_schema.Column.onupdate` or
+        :paramref:`_schema.Column.server_onupdate` parameters assigned, when
+        constructing the columns that will be included in the RETURNING clause
+        by default if explicit columns were not specified. When used against a
+        DELETE statement, no columns are included in RETURNING by default;
+        they must instead be specified explicitly, as there are no columns
+        that normally change values when a DELETE statement proceeds.
+
+        .. versionadded:: 2.0  :meth:`.UpdateBase.return_defaults` is supported
+           for DELETE statements also and has been moved from
+           :class:`.ValuesBase` to :class:`.UpdateBase`.
+
+        The :meth:`.UpdateBase.return_defaults` method is mutually exclusive
+        against the :meth:`.UpdateBase.returning` method and errors will be
+        raised during the SQL compilation process if both are used at the same
+        time on one statement. The RETURNING clause of the INSERT, UPDATE or
+        DELETE statement is therefore controlled by only one of these methods
+        at a time.
+
+        The :meth:`.UpdateBase.return_defaults` method differs from
+        :meth:`.UpdateBase.returning` in these ways:
+
+        1. :meth:`.UpdateBase.return_defaults` method causes the
+           :attr:`.CursorResult.returned_defaults` collection to be populated
+           with the first row from the RETURNING result. This attribute is not
+           populated when using :meth:`.UpdateBase.returning`.
+
+        2. :meth:`.UpdateBase.return_defaults` is compatible with existing
+           logic used to fetch auto-generated primary key values that are then
+           populated into the :attr:`.CursorResult.inserted_primary_key`
+           attribute. By contrast, using :meth:`.UpdateBase.returning` will
+           have the effect of the :attr:`.CursorResult.inserted_primary_key`
+           attribute being left unpopulated.
+
+        3. :meth:`.UpdateBase.return_defaults` can be called against any
+           backend. Backends that don't support RETURNING will skip the usage
+           of the feature, rather than raising an exception. The return value
+           of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
+           for backends that don't support RETURNING or for which the target
+           :class:`.Table` sets :paramref:`.Table.implicit_returning` to
+           ``False``.
+
+        4. An INSERT statement invoked with executemany() is supported if the
+           backend database driver supports the
+           :ref:`insertmanyvalues <engine_insertmanyvalues>`
+           feature which is now supported by most SQLAlchemy-included backends.
+           When executemany is used, the
+           :attr:`_engine.CursorResult.returned_defaults_rows` and
+           :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
+           will return the inserted defaults and primary keys.
+
+           .. versionadded:: 1.4 Added
+              :attr:`_engine.CursorResult.returned_defaults_rows` and
+              :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
+              In version 2.0, the underlying implementation which fetches and
+              populates the data for these attributes was generalized to be
+              supported by most backends, whereas in 1.4 they were only
+              supported by the ``psycopg2`` driver.
+
+
+        :param cols: optional list of column key names or
+         :class:`_schema.Column` that acts as a filter for those columns that
+         will be fetched.
+        :param supplemental_cols: optional list of RETURNING expressions,
+          in the same form as one would pass to the
+          :meth:`.UpdateBase.returning` method. When present, the additional
+          columns will be included in the RETURNING clause, and the
+          :class:`.CursorResult` object will be "rewound" when returned, so
+          that methods like :meth:`.CursorResult.all` will return new rows
+          mostly as though the statement used :meth:`.UpdateBase.returning`
+          directly. However, unlike when using :meth:`.UpdateBase.returning`
+          directly, the **order of the columns is undefined**, so they can
+          only be targeted using names or :attr:`.Row._mapping` keys; they
+          cannot reliably be targeted positionally.
+
+          .. versionadded:: 2.0
+
+        .. seealso::
+
+            :meth:`.UpdateBase.returning`
+
+            :attr:`_engine.CursorResult.returned_defaults`
+
+            :attr:`_engine.CursorResult.returned_defaults_rows`
+
+            :attr:`_engine.CursorResult.inserted_primary_key`
+
+            :attr:`_engine.CursorResult.inserted_primary_key_rows`
+
+        """
+
+        if self._return_defaults:
+            # note _return_defaults_columns = () means return all columns,
+            # so if we have been here before, only update collection if there
+            # are columns in the collection
+            if self._return_defaults_columns and cols:
+                self._return_defaults_columns = tuple(
+                    util.OrderedSet(self._return_defaults_columns).union(
+                        coercions.expect(roles.ColumnsClauseRole, c)
+                        for c in cols
+                    )
+                )
+            else:
+                # set for all columns
+                self._return_defaults_columns = ()
+        else:
+            self._return_defaults_columns = tuple(
+                coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+            )
+        self._return_defaults = True
+
+        if supplemental_cols:
+            # uniquifying while also maintaining order (order is maintained
+            # both for test suites and for vertical splicing)
+            supplemental_col_tup = (
+                coercions.expect(roles.ColumnsClauseRole, c)
+                for c in supplemental_cols
+            )
+
+            if self._supplemental_returning is None:
+                self._supplemental_returning = tuple(
+                    util.unique_list(supplemental_col_tup)
+                )
+            else:
+                self._supplemental_returning = tuple(
+                    util.unique_list(
+                        self._supplemental_returning
+                        + tuple(supplemental_col_tup)
+                    )
+                )
+
+        return self
+
     @_generative
     def returning(
         self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
@@ -500,7 +726,7 @@ class UpdateBase(
 
         .. seealso::
 
-          :meth:`.ValuesBase.return_defaults` - an alternative method tailored
+          :meth:`.UpdateBase.return_defaults` - an alternative method tailored
           towards efficient fetching of server-side defaults and triggers
           for single-row INSERTs or UPDATEs.
 
@@ -703,7 +929,6 @@ class ValuesBase(UpdateBase):
 
     _select_names: Optional[List[str]] = None
     _inline: bool = False
-    _returning: Tuple[_ColumnsClauseElement, ...] = ()
 
     def __init__(self, table: _DMLTableArgument):
         self.table = coercions.expect(
@@ -859,7 +1084,15 @@ class ValuesBase(UpdateBase):
                 )
 
             elif isinstance(arg, collections_abc.Sequence):
-                if arg and isinstance(arg[0], (list, dict, tuple)):
+
+                if arg and isinstance(arg[0], dict):
+                    multi_kv_generator = DMLState.get_plugin_class(
+                        self
+                    )._get_multi_crud_kv_pairs
+                    self._multi_values += (multi_kv_generator(self, arg),)
+                    return self
+
+                if arg and isinstance(arg[0], (list, tuple)):
                     self._multi_values += (arg,)
                     return self
 
@@ -888,173 +1121,13 @@ class ValuesBase(UpdateBase):
         # and ensures they get the "crud"-style name when rendered.
 
         kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
-        coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
+        coerced_arg = dict(kv_generator(self, arg.items(), True))
         if self._values:
             self._values = self._values.union(coerced_arg)
         else:
             self._values = util.immutabledict(coerced_arg)
         return self
 
-    @_generative
-    def return_defaults(
-        self: SelfValuesBase, *cols: _DMLColumnArgument
-    ) -> SelfValuesBase:
-        """Make use of a :term:`RETURNING` clause for the purpose
-        of fetching server-side expressions and defaults, for supporting
-        backends only.
-
-        .. tip::
-
-            The :meth:`.ValuesBase.return_defaults` method is used by the ORM
-            for its internal work in fetching newly generated primary key
-            and server default values, in particular to provide the underyling
-            implementation of the :paramref:`_orm.Mapper.eager_defaults`
-            ORM feature.  Its behavior is fairly idiosyncratic
-            and is not really intended for general use.  End users should
-            stick with using :meth:`.UpdateBase.returning` in order to
-            add RETURNING clauses to their INSERT, UPDATE and DELETE
-            statements.
-
-        Normally, a single row INSERT statement will automatically populate the
-        :attr:`.CursorResult.inserted_primary_key` attribute when executed,
-        which stores the primary key of the row that was just inserted in the
-        form of a :class:`.Row` object with column names as named tuple keys
-        (and the :attr:`.Row._mapping` view fully populated as well). The
-        dialect in use chooses the strategy to use in order to populate this
-        data; if it was generated using server-side defaults and / or SQL
-        expressions, dialect-specific approaches such as ``cursor.lastrowid``
-        or ``RETURNING`` are typically used to acquire the new primary key
-        value.
-
-        However, when the statement is modified by calling
-        :meth:`.ValuesBase.return_defaults` before executing the statement,
-        additional behaviors take place **only** for backends that support
-        RETURNING and for :class:`.Table` objects that maintain the
-        :paramref:`.Table.implicit_returning` parameter at its default value of
-        ``True``. In these cases, when the :class:`.CursorResult` is returned
-        from the statement's execution, not only will
-        :attr:`.CursorResult.inserted_primary_key` be populated as always, the
-        :attr:`.CursorResult.returned_defaults` attribute will also be
-        populated with a :class:`.Row` named-tuple representing the full range
-        of server generated
-        values from that single row, including values for any columns that
-        specify :paramref:`_schema.Column.server_default` or which make use of
-        :paramref:`_schema.Column.default` using a SQL expression.
-
-        When invoking INSERT statements with multiple rows using
-        :ref:`insertmanyvalues <engine_insertmanyvalues>`, the
-        :meth:`.ValuesBase.return_defaults` modifier will have the effect of
-        the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
-        :attr:`_engine.CursorResult.returned_defaults_rows` attributes being
-        fully populated with lists of :class:`.Row` objects representing newly
-        inserted primary key values as well as newly inserted server generated
-        values for each row inserted. The
-        :attr:`.CursorResult.inserted_primary_key` and
-        :attr:`.CursorResult.returned_defaults` attributes will also continue
-        to be populated with the first row of these two collections.
-
-        If the backend does not support RETURNING or the :class:`.Table` in use
-        has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
-        clause is added and no additional data is fetched, however the
-        INSERT or UPDATE statement proceeds normally.
-
-
-        E.g.::
-
-            stmt = table.insert().values(data='newdata').return_defaults()
-
-            result = connection.execute(stmt)
-
-            server_created_at = result.returned_defaults['created_at']
-
-
-        The :meth:`.ValuesBase.return_defaults` method is mutually exclusive
-        against the :meth:`.UpdateBase.returning` method and errors will be
-        raised during the SQL compilation process if both are used at the same
-        time on one statement. The RETURNING clause of the INSERT or UPDATE
-        statement is therefore controlled by only one of these methods at a
-        time.
-
-        The :meth:`.ValuesBase.return_defaults` method differs from
-        :meth:`.UpdateBase.returning` in these ways:
-
-        1. :meth:`.ValuesBase.return_defaults` method causes the
-           :attr:`.CursorResult.returned_defaults` collection to be populated
-           with the first row from the RETURNING result. This attribute is not
-           populated when using :meth:`.UpdateBase.returning`.
-
-        2. :meth:`.ValuesBase.return_defaults` is compatible with existing
-           logic used to fetch auto-generated primary key values that are then
-           populated into the :attr:`.CursorResult.inserted_primary_key`
-           attribute. By contrast, using :meth:`.UpdateBase.returning` will
-           have the effect of the :attr:`.CursorResult.inserted_primary_key`
-           attribute being left unpopulated.
-
-        3. :meth:`.ValuesBase.return_defaults` can be called against any
-           backend. Backends that don't support RETURNING will skip the usage
-           of the feature, rather than raising an exception. The return value
-           of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
-           for backends that don't support RETURNING or for which the target
-           :class:`.Table` sets :paramref:`.Table.implicit_returning` to
-           ``False``.
-
-        4. An INSERT statement invoked with executemany() is supported if the
-           backend database driver supports the
-           :ref:`insertmanyvalues <engine_insertmanyvalues>`
-           feature which is now supported by most SQLAlchemy-included backends.
-           When executemany is used, the
-           :attr:`_engine.CursorResult.returned_defaults_rows` and
-           :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
-           will return the inserted defaults and primary keys.
-
-           .. versionadded:: 1.4 Added
-              :attr:`_engine.CursorResult.returned_defaults_rows` and
-              :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
-              In version 2.0, the underlying implementation which fetches and
-              populates the data for these attributes was generalized to be
-              supported by most backends, whereas in 1.4 they were only
-              supported by the ``psycopg2`` driver.
-
-
-        :param cols: optional list of column key names or
-         :class:`_schema.Column` that acts as a filter for those columns that
-         will be fetched.
-
-        .. seealso::
-
-            :meth:`.UpdateBase.returning`
-
-            :attr:`_engine.CursorResult.returned_defaults`
-
-            :attr:`_engine.CursorResult.returned_defaults_rows`
-
-            :attr:`_engine.CursorResult.inserted_primary_key`
-
-            :attr:`_engine.CursorResult.inserted_primary_key_rows`
-
-        """
-
-        if self._return_defaults:
-            # note _return_defaults_columns = () means return all columns,
-            # so if we have been here before, only update collection if there
-            # are columns in the collection
-            if self._return_defaults_columns and cols:
-                self._return_defaults_columns = tuple(
-                    set(self._return_defaults_columns).union(
-                        coercions.expect(roles.ColumnsClauseRole, c)
-                        for c in cols
-                    )
-                )
-            else:
-                # set for all columns
-                self._return_defaults_columns = ()
-        else:
-            self._return_defaults_columns = tuple(
-                coercions.expect(roles.ColumnsClauseRole, c) for c in cols
-            )
-        self._return_defaults = True
-        return self
-
 
 SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
 
@@ -1459,7 +1532,7 @@ class Update(DMLWhereBase, ValuesBase):
             )
 
         kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
-        self._ordered_values = kv_generator(self, args)
+        self._ordered_values = kv_generator(self, args, True)
         return self
 
     @_generative
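
A hedged sketch of the supplemental_cols parameter documented above, compile-only and with an illustrative table; the extra column is appended to the RETURNING clause alongside whatever return_defaults() selects on its own (column order undefined, as noted):

    from sqlalchemy import Column, Integer, MetaData, String, Table, insert
    from sqlalchemy.dialects import postgresql

    metadata = MetaData()
    widgets = Table(
        "widgets",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", String(50)),
        Column("updated", String(50), server_default="now"),
    )

    stmt = insert(widgets).values(data="new").return_defaults(
        supplemental_cols=[widgets.c.data]
    )
    print(stmt.compile(dialect=postgresql.dialect()))
    # the RETURNING clause covers the primary key and server-default columns
    # chosen by return_defaults(), plus widgets.data from supplemental_cols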
index 4416fe630e1fd7069817406491078f8d49d1e813..b3a71dbffc2b7717a60f40c9e0017dc191155431 100644 (file)
@@ -68,7 +68,7 @@ class CursorSQL(SQLMatchRule):
 
 class CompiledSQL(SQLMatchRule):
     def __init__(
-        self, statement, params=None, dialect="default", enable_returning=False
+        self, statement, params=None, dialect="default", enable_returning=True
     ):
         self.statement = statement
         self.params = params
@@ -90,6 +90,17 @@ class CompiledSQL(SQLMatchRule):
                 dialect.insert_returning = (
                     dialect.update_returning
                 ) = dialect.delete_returning = True
+                dialect.use_insertmanyvalues = True
+                dialect.supports_multivalues_insert = True
+                dialect.update_returning_multifrom = True
+                dialect.delete_returning_multifrom = True
+                # dialect.favor_returning_over_lastrowid = True
+                # dialect.insert_null_pk_still_autoincrements = True
+
+                # this flag is normally calculated, but it needs to be True
+                # here to look like all the current RETURNING dialects
+                assert dialect.insert_executemany_returning
+
             return dialect
         else:
             return url.URL.create(self.dialect).get_dialect()()
index 20dee5273bb96150517a7b0e0b3e09acc59f4de1..ef284babc1d9b9a3e1e10fe04bb516f73d4ba823 100644 (file)
@@ -23,7 +23,6 @@ from .util import adict
 from .util import drop_all_tables_from_metadata
 from .. import event
 from .. import util
-from ..orm import declarative_base
 from ..orm import DeclarativeBase
 from ..orm import MappedAsDataclass
 from ..orm import registry
@@ -117,7 +116,7 @@ class TestBase:
             metadata=metadata,
             type_annotation_map={
                 str: sa.String().with_variant(
-                    sa.String(50), "mysql", "mariadb"
+                    sa.String(50), "mysql", "mariadb", "oracle"
                 )
             },
         )
@@ -132,7 +131,7 @@ class TestBase:
             metadata = _md
             type_annotation_map = {
                 str: sa.String().with_variant(
-                    sa.String(50), "mysql", "mariadb"
+                    sa.String(50), "mysql", "mariadb", "oracle"
                 )
             }
 
@@ -780,18 +779,19 @@ class DeclarativeMappedTest(MappedTest):
     def _with_register_classes(cls, fn):
         cls_registry = cls.classes
 
-        class DeclarativeBasic:
+        class _DeclBase(DeclarativeBase):
             __table_cls__ = schema.Table
+            metadata = cls._tables_metadata
+            type_annotation_map = {
+                str: sa.String().with_variant(
+                    sa.String(50), "mysql", "mariadb", "oracle"
+                )
+            }
 
-            def __init_subclass__(cls) -> None:
+            def __init_subclass__(cls, **kw) -> None:
                 assert cls_registry is not None
                 cls_registry[cls.__name__] = cls
-                super().__init_subclass__()
-
-        _DeclBase = declarative_base(
-            metadata=cls._tables_metadata,
-            cls=DeclarativeBasic,
-        )
+                super().__init_subclass__(**kw)
 
         cls.DeclarativeBasic = _DeclBase
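
The equivalent pattern in application code, sketched minimally with a stand-in ``Base`` name:

    import sqlalchemy as sa
    from sqlalchemy.orm import DeclarativeBase


    class Base(DeclarativeBase):
        # plain ``str`` annotations map to an unbounded String, with a
        # length-50 variant on backends that require an explicit length
        type_annotation_map = {
            str: sa.String().with_variant(
                sa.String(50), "mysql", "mariadb", "oracle"
            )
        }
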
 
index b7d4b7452f9f55da4b041b6323227b36d1812787..8e19a24a8c3d3ab608b36b9a9d7d2646fa1c0536 100644 (file)
@@ -89,8 +89,13 @@ class RowCountTest(fixtures.TablesTest):
         eq_(r.rowcount, 3)
 
     @testing.requires.update_returning
-    @testing.requires.sane_rowcount_w_returning
     def test_update_rowcount_return_defaults(self, connection):
+        """note this test should succeed for all RETURNING backends
+        as of 2.0.  In
+        Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
+        len(rows) when we have implicit returning
+
+        """
         employees_table = self.tables.employees
 
         department = employees_table.c.department
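
A sketch of what the rowcount change being tested here means for application code; ``employees`` is a hypothetical Table, ``connection`` an open Connection, and the column names are illustrative:

    from sqlalchemy import update

    stmt = (
        update(employees)
        .where(employees.c.department == "C")
        .values(name="updated")
        .returning(employees.c.employee_id)
    )
    result = connection.execute(stmt)
    rows = result.all()
    # with RETURNING present, rowcount is now derived from len(rows), so
    # it stays accurate even on drivers that don't otherwise report a
    # usable rowcount together with RETURNING
    assert result.rowcount == len(rows)
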
index f8348714c57c27d462b270a74a2b8151fad8bfd2..488229abbe9c01a9a5d3195962625a47418eb44c 100644 (file)
@@ -11,6 +11,7 @@ from __future__ import annotations
 from itertools import filterfalse
 from typing import AbstractSet
 from typing import Any
+from typing import Callable
 from typing import cast
 from typing import Collection
 from typing import Dict
@@ -481,7 +482,9 @@ class IdentitySet:
         return "%s(%r)" % (type(self).__name__, list(self._members.values()))
 
 
-def unique_list(seq, hashfunc=None):
+def unique_list(
+    seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
+) -> List[_T]:
     seen: Set[Any] = set()
     seen_add = seen.add
     if not hashfunc:
index 7cc6a6f790a567c9d6000f073d64a8e683cb0fef..667f4bfb08da43130e53cb09a7ec242f10dc2d9d 100644 (file)
@@ -465,7 +465,11 @@ class ShardTest:
         t = get_tokyo(sess2)
         eq_(t.city, tokyo.city)
 
-    def test_bulk_update_synchronize_evaluate(self):
+    @testing.combinations(
+        "fetch", "evaluate", "auto", argnames="synchronize_session"
+    )
+    @testing.combinations(True, False, argnames="legacy")
+    def test_orm_update_synchronize(self, synchronize_session, legacy):
         sess = self._fixture_data()
 
         eq_(
@@ -476,33 +480,25 @@ class ShardTest:
         temps = sess.query(Report).all()
         eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
 
-        sess.query(Report).filter(Report.temperature >= 80).update(
-            {"temperature": Report.temperature + 6},
-            synchronize_session="evaluate",
-        )
-
-        eq_(
-            set(row.temperature for row in sess.query(Report.temperature)),
-            {86.0, 75.0, 91.0},
-        )
-
-        # test synchronize session as well
-        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
-    def test_bulk_update_synchronize_fetch(self):
-        sess = self._fixture_data()
-
-        eq_(
-            set(row.temperature for row in sess.query(Report.temperature)),
-            {80.0, 75.0, 85.0},
-        )
+        if legacy:
+            sess.query(Report).filter(Report.temperature >= 80).update(
+                {"temperature": Report.temperature + 6},
+                synchronize_session=synchronize_session,
+            )
+        else:
+            sess.execute(
+                update(Report)
+                .filter(Report.temperature >= 80)
+                .values(temperature=Report.temperature + 6)
+                .execution_options(synchronize_session=synchronize_session)
+            )
 
-        temps = sess.query(Report).all()
-        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+        # test synchronize session
+        def go():
+            eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
 
-        sess.query(Report).filter(Report.temperature >= 80).update(
-            {"temperature": Report.temperature + 6},
-            synchronize_session="fetch",
+        self.assert_sql_count(
+            sess._ShardedSession__binds["north_america"], go, 0
         )
 
         eq_(
@@ -510,165 +506,41 @@ class ShardTest:
             {86.0, 75.0, 91.0},
         )
 
-        # test synchronize session as well
-        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
-    def test_bulk_delete_synchronize_evaluate(self):
-        sess = self._fixture_data()
-
-        temps = sess.query(Report).all()
-        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
-        sess.query(Report).filter(Report.temperature >= 80).delete(
-            synchronize_session="evaluate"
-        )
-
-        eq_(
-            set(row.temperature for row in sess.query(Report.temperature)),
-            {75.0},
-        )
-
-        # test synchronize session as well
-        for t in temps:
-            assert inspect(t).deleted is (t.temperature >= 80)
-
-    def test_bulk_delete_synchronize_fetch(self):
+    @testing.combinations(
+        "fetch", "evaluate", "auto", argnames="synchronize_session"
+    )
+    @testing.combinations(True, False, argnames="legacy")
+    def test_orm_delete_synchronize(self, synchronize_session, legacy):
         sess = self._fixture_data()
 
         temps = sess.query(Report).all()
         eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
 
-        sess.query(Report).filter(Report.temperature >= 80).delete(
-            synchronize_session="fetch"
-        )
-
-        eq_(
-            set(row.temperature for row in sess.query(Report.temperature)),
-            {75.0},
-        )
-
-        # test synchronize session as well
-        for t in temps:
-            assert inspect(t).deleted is (t.temperature >= 80)
-
-    def test_bulk_update_future_synchronize_evaluate(self):
-        sess = self._fixture_data()
-
-        eq_(
-            set(
-                row.temperature
-                for row in sess.execute(select(Report.temperature))
-            ),
-            {80.0, 75.0, 85.0},
-        )
-
-        temps = sess.execute(select(Report)).scalars().all()
-        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
-        sess.execute(
-            update(Report)
-            .filter(Report.temperature >= 80)
-            .values(
-                {"temperature": Report.temperature + 6},
+        if legacy:
+            sess.query(Report).filter(Report.temperature >= 80).delete(
+                synchronize_session=synchronize_session
             )
-            .execution_options(synchronize_session="evaluate")
-        )
-
-        eq_(
-            set(
-                row.temperature
-                for row in sess.execute(select(Report.temperature))
-            ),
-            {86.0, 75.0, 91.0},
-        )
-
-        # test synchronize session as well
-        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
-    def test_bulk_update_future_synchronize_fetch(self):
-        sess = self._fixture_data()
-
-        eq_(
-            set(
-                row.temperature
-                for row in sess.execute(select(Report.temperature))
-            ),
-            {80.0, 75.0, 85.0},
-        )
-
-        temps = sess.execute(select(Report)).scalars().all()
-        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
-        # MARKMARK
-        # omitting the criteria so that the UPDATE affects three out of
-        # four shards
-        sess.execute(
-            update(Report)
-            .values(
-                {"temperature": Report.temperature + 6},
+        else:
+            sess.execute(
+                delete(Report)
+                .filter(Report.temperature >= 80)
+                .execution_options(synchronize_session=synchronize_session)
             )
-            .execution_options(synchronize_session="fetch")
-        )
-
-        eq_(
-            set(
-                row.temperature
-                for row in sess.execute(select(Report.temperature))
-            ),
-            {86.0, 81.0, 91.0},
-        )
-
-        # test synchronize session as well
-        eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0})
-
-    def test_bulk_delete_future_synchronize_evaluate(self):
-        sess = self._fixture_data()
-
-        temps = sess.execute(select(Report)).scalars().all()
-        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
-        sess.execute(
-            delete(Report)
-            .filter(Report.temperature >= 80)
-            .execution_options(synchronize_session="evaluate")
-        )
 
-        eq_(
-            set(
-                row.temperature
-                for row in sess.execute(select(Report.temperature))
-            ),
-            {75.0},
-        )
-
-        # test synchronize session as well
-        for t in temps:
-            assert inspect(t).deleted is (t.temperature >= 80)
-
-    def test_bulk_delete_future_synchronize_fetch(self):
-        sess = self._fixture_data()
-
-        temps = sess.execute(select(Report)).scalars().all()
-        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+        def go():
+            # test synchronize session
+            for t in temps:
+                assert inspect(t).deleted is (t.temperature >= 80)
 
-        sess.execute(
-            delete(Report)
-            .filter(Report.temperature >= 80)
-            .execution_options(synchronize_session="fetch")
+        self.assert_sql_count(
+            sess._ShardedSession__binds["north_america"], go, 0
         )
 
         eq_(
-            set(
-                row.temperature
-                for row in sess.execute(select(Report.temperature))
-            ),
+            set(row.temperature for row in sess.query(Report.temperature)),
             {75.0},
         )
 
-        # test synchronize session as well
-        for t in temps:
-            assert inspect(t).deleted is (t.temperature >= 80)
-
 
 class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
     def _init_dbs(self):
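
In application code, the 2.0-style statements parametrized above look roughly like the following sketch; ``Report`` is the horizontal-sharding example's mapped class and ``session`` an active Session:

    from sqlalchemy import delete, update

    # synchronize_session defaults to "auto", which uses "fetch" when
    # RETURNING is available and falls back to "evaluate" otherwise
    session.execute(
        update(Report)
        .where(Report.temperature >= 80)
        .values(temperature=Report.temperature + 6)
        .execution_options(synchronize_session="evaluate")
    )

    session.execute(
        delete(Report)
        .where(Report.temperature >= 80)
        .execution_options(synchronize_session="auto")
    )
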
index de5f89b25274068fecf3bd0601d0a79bf54f01fe..0cba8f3a15e0e295c6b3c6fed0ca5f2198942ea7 100644 (file)
@@ -3,6 +3,7 @@ from decimal import Decimal
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -1017,15 +1018,43 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
             params={"first_name": "Dr."},
         )
 
-    def test_update_expr(self):
+    @testing.combinations("attr", "str", "kwarg", argnames="keytype")
+    def test_update_expr(self, keytype):
         Person = self.classes.Person
 
-        statement = update(Person).values({Person.name: "Dr. No"})
+        if keytype == "attr":
+            statement = update(Person).values({Person.name: "Dr. No"})
+        elif keytype == "str":
+            statement = update(Person).values({"name": "Dr. No"})
+        elif keytype == "kwarg":
+            statement = update(Person).values(name="Dr. No")
+        else:
+            assert False
 
         self.assert_compile(
             statement,
             "UPDATE person SET first_name=:first_name, last_name=:last_name",
-            params={"first_name": "Dr.", "last_name": "No"},
+            checkparams={"first_name": "Dr.", "last_name": "No"},
+        )
+
+    @testing.combinations("attr", "str", "kwarg", argnames="keytype")
+    def test_insert_expr(self, keytype):
+        Person = self.classes.Person
+
+        if keytype == "attr":
+            statement = insert(Person).values({Person.name: "Dr. No"})
+        elif keytype == "str":
+            statement = insert(Person).values({"name": "Dr. No"})
+        elif keytype == "kwarg":
+            statement = insert(Person).values(name="Dr. No")
+        else:
+            assert False
+
+        self.assert_compile(
+            statement,
+            "INSERT INTO person (first_name, last_name) VALUES "
+            "(:first_name, :last_name)",
+            checkparams={"first_name": "Dr.", "last_name": "No"},
         )
 
     # these tests all run two UPDATES to assert that caching is not
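
All three key styles exercised above are interchangeable at the ORM level, sketched here; ``Person`` is this fixture's mapped class, whose ``name`` attribute the fixture maps as a composite over ``first_name`` / ``last_name``:

    from sqlalchemy import insert, update

    # attribute, string and keyword forms all resolve the ORM attribute
    # to its mapped column(s), including composite attributes
    insert(Person).values({Person.name: "Dr. No"})
    insert(Person).values({"name": "Dr. No"})
    insert(Person).values(name="Dr. No")

    update(Person).values({Person.name: "Dr. No"})
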
diff --git a/test/orm/dml/__init__.py b/test/orm/dml/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
similarity index 80%
rename from test/orm/test_bulk.py
rename to test/orm/dml/test_bulk.py
index 802cdfac5f218d3b03eebd83bb1d511aa37a10b5..52db4247f7b7a15bf6bbf604ea998646062ae5dd 100644 (file)
@@ -1,8 +1,11 @@
 from sqlalchemy import FetchedValue
 from sqlalchemy import ForeignKey
+from sqlalchemy import Identity
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import update
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import mock
@@ -20,6 +23,8 @@ class BulkTest(testing.AssertsExecutionResults):
 
 
 class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
+    __backend__ = True
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -73,6 +78,8 @@ class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
 
 
 class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
+    __backend__ = True
+
     @classmethod
     def setup_mappers(cls):
         User, Address, Order = cls.classes("User", "Address", "Order")
@@ -82,22 +89,42 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
         cls.mapper_registry.map_imperatively(Address, a)
         cls.mapper_registry.map_imperatively(Order, o)
 
-    def test_bulk_save_return_defaults(self):
+    @testing.combinations("save_objects", "insert_mappings", "insert_stmt")
+    def test_bulk_save_return_defaults(self, statement_type):
         (User,) = self.classes("User")
 
         s = fixture_session()
-        objects = [User(name="u1"), User(name="u2"), User(name="u3")]
-        assert "id" not in objects[0].__dict__
 
-        with self.sql_execution_asserter() as asserter:
-            s.bulk_save_objects(objects, return_defaults=True)
+        if statement_type == "save_objects":
+            objects = [User(name="u1"), User(name="u2"), User(name="u3")]
+            assert "id" not in objects[0].__dict__
+
+            returning_users_id = " RETURNING users.id"
+            with self.sql_execution_asserter() as asserter:
+                s.bulk_save_objects(objects, return_defaults=True)
+        elif statement_type == "insert_mappings":
+            data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
+            returning_users_id = " RETURNING users.id"
+            with self.sql_execution_asserter() as asserter:
+                s.bulk_insert_mappings(User, data, return_defaults=True)
+        elif statement_type == "insert_stmt":
+            data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
+
+            # for the statement form, "return_defaults" is applied
+            # heuristically based on whether we are a joined-inheritance
+            # mapping, when .returning() is not otherwise present on the
+            # statement itself
+            returning_users_id = ""
+            with self.sql_execution_asserter() as asserter:
+                s.execute(insert(User), data)
 
         asserter.assert_(
             Conditional(
-                testing.db.dialect.insert_executemany_returning,
+                testing.db.dialect.insert_executemany_returning
+                or statement_type == "insert_stmt",
                 [
                     CompiledSQL(
-                        "INSERT INTO users (name) VALUES (:name)",
+                        "INSERT INTO users (name) "
+                        f"VALUES (:name){returning_users_id}",
                         [{"name": "u1"}, {"name": "u2"}, {"name": "u3"}],
                     ),
                 ],
@@ -117,7 +144,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
                 ],
             )
         )
-        eq_(objects[0].__dict__["id"], 1)
+        if statement_type == "save_objects":
+            eq_(objects[0].__dict__["id"], 1)
 
     def test_bulk_save_mappings_preserve_order(self):
         (User,) = self.classes("User")
@@ -219,8 +247,9 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
             )
         )
 
-    def test_bulk_update(self):
-        (User,) = self.classes("User")
+    @testing.combinations("update_mappings", "update_stmt")
+    def test_bulk_update(self, statement_type):
+        User = self.classes.User
 
         s = fixture_session(expire_on_commit=False)
         objects = [User(name="u1"), User(name="u2"), User(name="u3")]
@@ -228,15 +257,18 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
         s.commit()
 
         s = fixture_session()
-        with self.sql_execution_asserter() as asserter:
-            s.bulk_update_mappings(
-                User,
-                [
-                    {"id": 1, "name": "u1new"},
-                    {"id": 2, "name": "u2"},
-                    {"id": 3, "name": "u3new"},
-                ],
-            )
+        data = [
+            {"id": 1, "name": "u1new"},
+            {"id": 2, "name": "u2"},
+            {"id": 3, "name": "u3new"},
+        ]
+
+        if statement_type == "update_mappings":
+            with self.sql_execution_asserter() as asserter:
+                s.bulk_update_mappings(User, data)
+        elif statement_type == "update_stmt":
+            with self.sql_execution_asserter() as asserter:
+                s.execute(update(User), data)
 
         asserter.assert_(
             CompiledSQL(
@@ -303,6 +335,8 @@ class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
 
 
 class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
+    __backend__ = True
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -360,6 +394,8 @@ class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
 
 
 class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
+    __backend__ = True
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -547,6 +583,8 @@ class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
 
 
 class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
+    __backend__ = True
+
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -643,6 +681,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
         )
 
         s = fixture_session()
+
         objects = [
             Manager(name="m1", status="s1", manager_name="mn1"),
             Engineer(name="e1", status="s2", primary_language="l1"),
@@ -669,7 +708,7 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
                 [
                     CompiledSQL(
                         "INSERT INTO people (name, type) "
-                        "VALUES (:name, :type)",
+                        "VALUES (:name, :type) RETURNING people.person_id",
                         [
                             {"type": "engineer", "name": "e1"},
                             {"type": "engineer", "name": "e2"},
@@ -798,59 +837,74 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
             ),
         )
 
-    def test_bulk_insert_joined_inh_return_defaults(self):
+    @testing.combinations("insert_mappings", "insert_stmt")
+    def test_bulk_insert_joined_inh_return_defaults(self, statement_type):
         Person, Engineer, Manager, Boss = self.classes(
             "Person", "Engineer", "Manager", "Boss"
         )
 
         s = fixture_session()
-        with self.sql_execution_asserter() as asserter:
-            s.bulk_insert_mappings(
-                Boss,
-                [
-                    dict(
-                        name="b1",
-                        status="s1",
-                        manager_name="mn1",
-                        golf_swing="g1",
-                    ),
-                    dict(
-                        name="b2",
-                        status="s2",
-                        manager_name="mn2",
-                        golf_swing="g2",
-                    ),
-                    dict(
-                        name="b3",
-                        status="s3",
-                        manager_name="mn3",
-                        golf_swing="g3",
-                    ),
-                ],
-                return_defaults=True,
-            )
+        data = [
+            dict(
+                name="b1",
+                status="s1",
+                manager_name="mn1",
+                golf_swing="g1",
+            ),
+            dict(
+                name="b2",
+                status="s2",
+                manager_name="mn2",
+                golf_swing="g2",
+            ),
+            dict(
+                name="b3",
+                status="s3",
+                manager_name="mn3",
+                golf_swing="g3",
+            ),
+        ]
+
+        if statement_type == "insert_mappings":
+            with self.sql_execution_asserter() as asserter:
+                s.bulk_insert_mappings(
+                    Boss,
+                    data,
+                    return_defaults=True,
+                )
+        elif statement_type == "insert_stmt":
+            with self.sql_execution_asserter() as asserter:
+                s.execute(insert(Boss), data)
 
         asserter.assert_(
             Conditional(
                 testing.db.dialect.insert_executemany_returning,
                 [
                     CompiledSQL(
-                        "INSERT INTO people (name) VALUES (:name)",
-                        [{"name": "b1"}, {"name": "b2"}, {"name": "b3"}],
+                        "INSERT INTO people (name, type) "
+                        "VALUES (:name, :type) RETURNING people.person_id",
+                        [
+                            {"name": "b1", "type": "boss"},
+                            {"name": "b2", "type": "boss"},
+                            {"name": "b3", "type": "boss"},
+                        ],
                     ),
                 ],
                 [
                     CompiledSQL(
-                        "INSERT INTO people (name) VALUES (:name)",
-                        [{"name": "b1"}],
+                        "INSERT INTO people (name, type) "
+                        "VALUES (:name, :type)",
+                        [{"name": "b1", "type": "boss"}],
                     ),
                     CompiledSQL(
-                        "INSERT INTO people (name) VALUES (:name)",
-                        [{"name": "b2"}],
+                        "INSERT INTO people (name, type) "
+                        "VALUES (:name, :type)",
+                        [{"name": "b2", "type": "boss"}],
                     ),
                     CompiledSQL(
-                        "INSERT INTO people (name) VALUES (:name)",
-                        [{"name": "b3"}],
+                        "INSERT INTO people (name, type) "
+                        "VALUES (:name, :type)",
+                        [{"name": "b3", "type": "boss"}],
                     ),
                 ],
             ),
@@ -874,15 +928,79 @@ class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
             ),
         )
 
+    @testing.combinations("update_mappings", "update_stmt")
+    def test_bulk_update(self, statement_type):
+        Person, Engineer, Manager, Boss = self.classes(
+            "Person", "Engineer", "Manager", "Boss"
+        )
+
+        s = fixture_session()
+
+        b1, b2, b3 = (
+            Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"),
+            Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"),
+            Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"),
+        )
+        s.add_all([b1, b2, b3])
+        s.commit()
+
+        # slightly inconvenient: we have to fill in boss_id here for the
+        # UPDATE, as it is not sent along automatically.  this is not new
+        # behavior for bulk operations
+        new_data = [
+            {
+                "person_id": b1.person_id,
+                "boss_id": b1.boss_id,
+                "name": "b1_updated",
+                "manager_name": "mn1_updated",
+            },
+            {
+                "person_id": b3.person_id,
+                "boss_id": b3.boss_id,
+                "manager_name": "mn2_updated",
+                "golf_swing": "g1_updated",
+            },
+        ]
+
+        if statement_type == "update_mappings":
+            with self.sql_execution_asserter() as asserter:
+                s.bulk_update_mappings(Boss, new_data)
+        elif statement_type == "update_stmt":
+            with self.sql_execution_asserter() as asserter:
+                s.execute(update(Boss), new_data)
+
+        asserter.assert_(
+            CompiledSQL(
+                "UPDATE people SET name=:name WHERE "
+                "people.person_id = :people_person_id",
+                [{"name": "b1_updated", "people_person_id": 1}],
+            ),
+            CompiledSQL(
+                "UPDATE managers SET manager_name=:manager_name WHERE "
+                "managers.person_id = :managers_person_id",
+                [
+                    {"manager_name": "mn1_updated", "managers_person_id": 1},
+                    {"manager_name": "mn2_updated", "managers_person_id": 3},
+                ],
+            ),
+            CompiledSQL(
+                "UPDATE boss SET golf_swing=:golf_swing WHERE "
+                "boss.boss_id = :boss_boss_id",
+                [{"golf_swing": "g1_updated", "boss_boss_id": 3}],
+            ),
+        )
+
 
 class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
+    __backend__ = True
+
     @classmethod
     def setup_classes(cls):
         Base = cls.DeclarativeBasic
 
         class User(Base):
             __tablename__ = "users"
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, Identity(), primary_key=True)
             name = Column(String(255), nullable=False)
 
     def test_issue_6793(self):
@@ -907,7 +1025,8 @@ class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
                         [{"name": "A"}, {"name": "B"}],
                     ),
                     CompiledSQL(
-                        "INSERT INTO users (name) VALUES (:name)",
+                        "INSERT INTO users (name) VALUES (:name) "
+                        "RETURNING users.id",
                         [{"name": "C"}, {"name": "D"}],
                     ),
                 ],
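
The ``execute()``-based form being tested alongside ``bulk_update_mappings()`` above, as a minimal sketch with ``User`` and ``session`` assumed:

    from sqlalchemy import update

    # bulk UPDATE by primary key: each dictionary must include the
    # primary key value(s); the remaining keys become SET clauses.  this
    # is the execute()-based analogue of session.bulk_update_mappings()
    session.execute(
        update(User),
        [
            {"id": 1, "name": "u1new"},
            {"id": 2, "name": "u2"},
            {"id": 3, "name": "u3new"},
        ],
    )
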
diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py
new file mode 100644 (file)
index 0000000..0cca9e6
--- /dev/null
@@ -0,0 +1,1199 @@
+from __future__ import annotations
+
+from typing import Any
+from typing import List
+from typing import Optional
+import uuid
+
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import func
+from sqlalchemy import Identity
+from sqlalchemy import insert
+from sqlalchemy import inspect
+from sqlalchemy import literal_column
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.testing import config
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
+from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import fixture_session
+
+
+class NoReturningTest(fixtures.TestBase):
+    def test_no_returning_error(self, decl_base):
+        class A(fixtures.ComparableEntity, decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column("xcol")
+
+        decl_base.metadata.create_all(testing.db)
+        s = fixture_session()
+
+        if testing.requires.insert_executemany_returning.enabled:
+            result = s.scalars(
+                insert(A).returning(A),
+                [
+                    {"data": "d3", "x": 5},
+                    {"data": "d4", "x": 6},
+                ],
+            )
+            eq_(result.all(), [A(data="d3", x=5), A(data="d4", x=6)])
+
+        else:
+            with expect_raises_message(
+                exc.InvalidRequestError,
+                "Can't use explicit RETURNING for bulk INSERT operation",
+            ):
+                s.scalars(
+                    insert(A).returning(A),
+                    [
+                        {"data": "d3", "x": 5},
+                        {"data": "d4", "x": 6},
+                    ],
+                )
+
+    def test_omit_returning_ok(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column("xcol")
+
+        decl_base.metadata.create_all(testing.db)
+        s = fixture_session()
+
+        s.execute(
+            insert(A),
+            [
+                {"data": "d3", "x": 5},
+                {"data": "d4", "x": 6},
+            ],
+        )
+        eq_(
+            s.execute(select(A.data, A.x).order_by(A.id)).all(),
+            [("d3", 5), ("d4", 6)],
+        )
+
+
+class BulkDMLReturningInhTest:
+    def test_insert_col_key_also_works_currently(self):
+        """using the column key, not mapped attr key.
+
+        right now this passes through to the INSERT.  when doing this with
+        an UPDATE, it tends to fail because the synchronize_session
+        strategies can't match "xcol" back.  however with INSERT we aren't
+        doing that, so there's no place this gets checked.  UPDATE also
+        succeeds if synchronize_session is turned off.
+
+        """
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+        s.execute(insert(A).values(type="a", data="d", xcol=10))
+        eq_(s.scalars(select(A.x)).all(), [10])
+
+    @testing.combinations(True, False, argnames="use_returning")
+    def test_heterogeneous_keys(self, use_returning):
+        A, B = self.classes("A", "B")
+
+        values = [
+            {"data": "d3", "x": 5, "type": "a"},
+            {"data": "d4", "x": 6, "type": "a"},
+            {"data": "d5", "type": "a"},
+            {"data": "d6", "x": 8, "y": 9, "type": "a"},
+            {"data": "d7", "x": 12, "y": 12, "type": "a"},
+            {"data": "d8", "x": 7, "type": "a"},
+        ]
+
+        s = fixture_session()
+
+        stmt = insert(A)
+        if use_returning:
+            stmt = stmt.returning(A)
+
+        with self.sql_execution_asserter() as asserter:
+            result = s.execute(stmt, values)
+
+        if inspect(B).single:
+            single_inh = ", a.bd, a.zcol, a.q"
+        else:
+            single_inh = ""
+
+        if use_returning:
+            asserter.assert_(
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol) VALUES "
+                    "(:type, :data, :xcol) "
+                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    [
+                        {"type": "a", "data": "d3", "xcol": 5},
+                        {"type": "a", "data": "d4", "xcol": 6},
+                    ],
+                ),
+                CompiledSQL(
+                    "INSERT INTO a (type, data) VALUES (:type, :data) "
+                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    [{"type": "a", "data": "d5"}],
+                ),
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol, y) "
+                    "VALUES (:type, :data, :xcol, :y) "
+                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    [
+                        {"type": "a", "data": "d6", "xcol": 8, "y": 9},
+                        {"type": "a", "data": "d7", "xcol": 12, "y": 12},
+                    ],
+                ),
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol) "
+                    "VALUES (:type, :data, :xcol) "
+                    f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+                    [{"type": "a", "data": "d8", "xcol": 7}],
+                ),
+            )
+        else:
+            asserter.assert_(
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol) VALUES "
+                    "(:type, :data, :xcol)",
+                    [
+                        {"type": "a", "data": "d3", "xcol": 5},
+                        {"type": "a", "data": "d4", "xcol": 6},
+                    ],
+                ),
+                CompiledSQL(
+                    "INSERT INTO a (type, data) VALUES (:type, :data)",
+                    [{"type": "a", "data": "d5"}],
+                ),
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol, y) "
+                    "VALUES (:type, :data, :xcol, :y)",
+                    [
+                        {"type": "a", "data": "d6", "xcol": 8, "y": 9},
+                        {"type": "a", "data": "d7", "xcol": 12, "y": 12},
+                    ],
+                ),
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol) "
+                    "VALUES (:type, :data, :xcol)",
+                    [{"type": "a", "data": "d8", "xcol": 7}],
+                ),
+            )
+
+        if use_returning:
+            eq_(
+                result.scalars().all(),
+                [
+                    A(data="d3", id=mock.ANY, type="a", x=5, y=None),
+                    A(data="d4", id=mock.ANY, type="a", x=6, y=None),
+                    A(data="d5", id=mock.ANY, type="a", x=None, y=None),
+                    A(data="d6", id=mock.ANY, type="a", x=8, y=9),
+                    A(data="d7", id=mock.ANY, type="a", x=12, y=12),
+                    A(data="d8", id=mock.ANY, type="a", x=7, y=None),
+                ],
+            )
+
+    @testing.combinations(
+        "strings",
+        "cols",
+        "strings_w_exprs",
+        "cols_w_exprs",
+        argnames="paramstyle",
+    )
+    @testing.combinations(
+        True,
+        (False, testing.requires.multivalues_inserts),
+        argnames="single_element",
+    )
+    def test_single_values_returning_fn(self, paramstyle, single_element):
+        """test using insert().values().
+
+        these INSERT statements go straight in as a single execute without
+        any insertmanyvalues or bulk_insert_mappings process going on.  the
+        advantage here is that SQL expressions can be used in the values
+        also.  the disadvantage is that none of the automation for
+        inheritance mappers applies.
+
+        """
+        A, B = self.classes("A", "B")
+
+        if paramstyle == "strings":
+            values = [
+                {"data": "d3", "x": 5, "y": 9, "type": "a"},
+                {"data": "d4", "x": 10, "y": 8, "type": "a"},
+            ]
+        elif paramstyle == "cols":
+            values = [
+                {A.data: "d3", A.x: 5, A.y: 9, A.type: "a"},
+                {A.data: "d4", A.x: 10, A.y: 8, A.type: "a"},
+            ]
+        elif paramstyle == "strings_w_exprs":
+            values = [
+                {"data": func.lower("D3"), "x": 5, "y": 9, "type": "a"},
+                {
+                    "data": "d4",
+                    "x": literal_column("5") + 5,
+                    "y": 8,
+                    "type": "a",
+                },
+            ]
+        elif paramstyle == "cols_w_exprs":
+            values = [
+                {A.data: func.lower("D3"), A.x: 5, A.y: 9, A.type: "a"},
+                {
+                    A.data: "d4",
+                    A.x: literal_column("5") + 5,
+                    A.y: 8,
+                    A.type: "a",
+                },
+            ]
+        else:
+            assert False
+
+        s = fixture_session()
+
+        if single_element:
+            if paramstyle.startswith("strings"):
+                stmt = (
+                    insert(A)
+                    .values(**values[0])
+                    .returning(A, func.upper(A.data, type_=String))
+                )
+            else:
+                stmt = (
+                    insert(A)
+                    .values(values[0])
+                    .returning(A, func.upper(A.data, type_=String))
+                )
+        else:
+            stmt = (
+                insert(A)
+                .values(values)
+                .returning(A, func.upper(A.data, type_=String))
+            )
+
+        for i in range(3):
+            result = s.execute(stmt)
+            expected: List[Any] = [(A(data="d3", x=5, y=9), "D3")]
+            if not single_element:
+                expected.append((A(data="d4", x=10, y=8), "D4"))
+            eq_(result.all(), expected)
+
+    def test_bulk_w_sql_expressions(self):
+        A, B = self.classes("A", "B")
+
+        data = [
+            {"x": 5, "y": 9, "type": "a"},
+            {
+                "x": 10,
+                "y": 8,
+                "type": "a",
+            },
+        ]
+
+        s = fixture_session()
+
+        stmt = (
+            insert(A)
+            .values(data=func.lower("DD"))
+            .returning(A, func.upper(A.data, type_=String))
+        )
+
+        for i in range(3):
+            result = s.execute(stmt, data)
+            expected: List[Any] = [
+                (A(data="dd", x=5, y=9), "DD"),
+                (A(data="dd", x=10, y=8), "DD"),
+            ]
+            eq_(result.all(), expected)
+
+    def test_bulk_w_sql_expressions_subclass(self):
+        A, B = self.classes("A", "B")
+
+        data = [
+            {"bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+            {"bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+        ]
+
+        s = fixture_session()
+
+        stmt = (
+            insert(B)
+            .values(data=func.lower("DD"))
+            .returning(B, func.upper(B.data, type_=String))
+        )
+
+        for i in range(3):
+            result = s.execute(stmt, data)
+            expected: List[Any] = [
+                (B(bd="bd1", data="dd", q=4, type="b", x=1, y=2, z=3), "DD"),
+                (B(bd="bd2", data="dd", q=8, type="b", x=5, y=6, z=7), "DD"),
+            ]
+            eq_(result.all(), expected)
+
+    @testing.combinations(True, False, argnames="use_ordered")
+    def test_bulk_upd_w_sql_expressions_no_ordered_values(self, use_ordered):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        stmt = update(B).ordered_values(
+            ("data", func.lower("DD_UPDATE")),
+            ("z", literal_column("3 + 12")),
+        )
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"bulk ORM UPDATE does not support ordered_values\(\) "
+            r"for custom UPDATE",
+        ):
+            s.execute(
+                stmt,
+                [
+                    {"id": 5, "bd": "bd1_updated"},
+                    {"id": 6, "bd": "bd2_updated"},
+                ],
+            )
+
+    def test_bulk_upd_w_sql_expressions_subclass(self):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        data = [
+            {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+            {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+        ]
+        ids = s.scalars(insert(B).returning(B.id), data).all()
+
+        stmt = update(B).values(
+            data=func.lower("DD_UPDATE"), z=literal_column("3 + 12")
+        )
+
+        result = s.execute(
+            stmt,
+            [
+                {"id": ids[0], "bd": "bd1_updated"},
+                {"id": ids[1], "bd": "bd2_updated"},
+            ],
+        )
+
+        # this is a nullresult at the moment
+        assert result is not None
+
+        eq_(
+            s.scalars(select(B)).all(),
+            [
+                B(
+                    bd="bd1_updated",
+                    data="dd_update",
+                    id=ids[0],
+                    q=4,
+                    type="b",
+                    x=1,
+                    y=2,
+                    z=15,
+                ),
+                B(
+                    bd="bd2_updated",
+                    data="dd_update",
+                    id=ids[1],
+                    q=8,
+                    type="b",
+                    x=5,
+                    y=6,
+                    z=15,
+                ),
+            ],
+        )
+
+    def test_single_returning_fn(self):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+        for i in range(3):
+            result = s.execute(
+                insert(A).returning(A, func.upper(A.data, type_=String)),
+                [{"data": "d3"}, {"data": "d4"}],
+            )
+            eq_(result.all(), [(A(data="d3"), "D3"), (A(data="d4"), "D4")])
+
+    @testing.combinations(
+        True,
+        False,
+        argnames="single_element",
+    )
+    def test_subclass_no_returning(self, single_element):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        if single_element:
+            data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+        else:
+            data = [
+                {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+                {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+            ]
+
+        result = s.execute(insert(B), data)
+        assert result._soft_closed
+
+    @testing.combinations(
+        True,
+        False,
+        argnames="single_element",
+    )
+    def test_subclass_load_only(self, single_element):
+        """test that load_only() prevents additional attributes from being
+        populated.
+
+        """
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        if single_element:
+            data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+        else:
+            data = [
+                {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+                {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+            ]
+
+        for i in range(3):
+            # tests both caching and that the data dictionaries aren't
+            # mutated...
+            result = s.execute(
+                insert(B).returning(B).options(load_only(B.data, B.y, B.q)),
+                data,
+            )
+            objects = result.scalars().all()
+            for obj in objects:
+                assert "data" in obj.__dict__
+                assert "q" in obj.__dict__
+                assert "z" not in obj.__dict__
+                assert "x" not in obj.__dict__
+
+            expected = [
+                B(data="d3", bd="bd1", x=1, y=2, z=3, q=4),
+            ]
+            if not single_element:
+                expected.append(B(data="d4", bd="bd2", x=5, y=6, z=7, q=8))
+            eq_(objects, expected)
+
+    @testing.combinations(
+        True,
+        False,
+        argnames="single_element",
+    )
+    def test_subclass_load_only_doesnt_fetch_cols(self, single_element):
+        """test that when using load_only(), the actual INSERT statement
+        does not include the deferred columns
+
+        """
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        data = [
+            {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+            {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+        ]
+        if single_element:
+            data = data[0]
+
+        with self.sql_execution_asserter() as asserter:
+
+            # tests both caching and that the data dictionaries aren't
+            # mutated...
+
+            # note that if we don't put B.id here, accessing .id on the
+            # B object for joined inheritance triggers a SELECT
+            # (though not for single inheritance). this seems not great,
+            # but is likely a different issue
+            result = s.execute(
+                insert(B)
+                .returning(B)
+                .options(load_only(B.id, B.data, B.y, B.q)),
+                data,
+            )
+            objects = result.scalars().all()
+            if single_element:
+                id0 = objects[0].id
+                id1 = None
+            else:
+                id0, id1 = objects[0].id, objects[1].id
+
+        if inspect(B).single or inspect(B).concrete:
+            expected_params = [
+                {
+                    "type": "b",
+                    "data": "d3",
+                    "xcol": 1,
+                    "y": 2,
+                    "bd": "bd1",
+                    "zcol": 3,
+                    "q": 4,
+                },
+                {
+                    "type": "b",
+                    "data": "d4",
+                    "xcol": 5,
+                    "y": 6,
+                    "bd": "bd2",
+                    "zcol": 7,
+                    "q": 8,
+                },
+            ]
+            if single_element:
+                expected_params[1:] = []
+            # RETURNING only includes PK, discriminator, then the cols
+            # we asked for data, y, q.  xcol, z, bd are omitted
+
+            if inspect(B).single:
+                asserter.assert_(
+                    CompiledSQL(
+                        "INSERT INTO a (type, data, xcol, y, bd, zcol, q) "
+                        "VALUES "
+                        "(:type, :data, :xcol, :y, :bd, :zcol, :q) "
+                        "RETURNING a.id, a.type, a.data, a.y, a.q",
+                        expected_params,
+                    ),
+                )
+            else:
+                asserter.assert_(
+                    CompiledSQL(
+                        "INSERT INTO b (type, data, xcol, y, bd, zcol, q) "
+                        "VALUES "
+                        "(:type, :data, :xcol, :y, :bd, :zcol, :q) "
+                        "RETURNING b.id, b.type, b.data, b.y, b.q",
+                        expected_params,
+                    ),
+                )
+        else:
+            a_data = [
+                {"type": "b", "data": "d3", "xcol": 1, "y": 2},
+                {"type": "b", "data": "d4", "xcol": 5, "y": 6},
+            ]
+            b_data = [
+                {"id": id0, "bd": "bd1", "zcol": 3, "q": 4},
+                {"id": id1, "bd": "bd2", "zcol": 7, "q": 8},
+            ]
+            if single_element:
+                a_data[1:] = []
+                b_data[1:] = []
+            # RETURNING only includes PK, discriminator, then the cols
+            # we asked for data, y, q.  xcol, z, bd are omitted.  plus they
+            # are broken out correctly in the two statements.
+            asserter.assert_(
+                CompiledSQL(
+                    "INSERT INTO a (type, data, xcol, y) VALUES "
+                    "(:type, :data, :xcol, :y) "
+                    "RETURNING a.id, a.type, a.data, a.y",
+                    a_data,
+                ),
+                CompiledSQL(
+                    "INSERT INTO b (id, bd, zcol, q) "
+                    "VALUES (:id, :bd, :zcol, :q) "
+                    "RETURNING b.id, b.q",
+                    b_data,
+                ),
+            )
+
+    @testing.combinations(
+        True,
+        False,
+        argnames="single_element",
+    )
+    def test_subclass_returning_bind_expr(self, single_element):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        if single_element:
+            data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+        else:
+            data = [
+                {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+                {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+            ]
+        # note there's a fix in compiler.py ->
+        # _deliver_insertmanyvalues_batches
+        # for this re: the parameter rendering that isn't tested anywhere
+        # else.  two different versions of the bug for both positional
+        # and non
+        result = s.execute(insert(B).returning(B.data, B.y, B.q + 5), data)
+        if single_element:
+            eq_(result.all(), [("d3", 2, 9)])
+        else:
+            eq_(result.all(), [("d3", 2, 9), ("d4", 6, 13)])
+
+    def test_subclass_bulk_update(self):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        data = [
+            {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+            {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+        ]
+        ids = s.scalars(insert(B).returning(B.id), data).all()
+
+        result = s.execute(
+            update(B),
+            [
+                {"id": ids[0], "data": "d3_updated", "bd": "bd1_updated"},
+                {"id": ids[1], "data": "d4_updated", "bd": "bd2_updated"},
+            ],
+        )
+
+        # this is a nullresult at the moment
+        assert result is not None
+
+        eq_(
+            s.scalars(select(B)).all(),
+            [
+                B(
+                    bd="bd1_updated",
+                    data="d3_updated",
+                    id=ids[0],
+                    q=4,
+                    type="b",
+                    x=1,
+                    y=2,
+                    z=3,
+                ),
+                B(
+                    bd="bd2_updated",
+                    data="d4_updated",
+                    id=ids[1],
+                    q=8,
+                    type="b",
+                    x=5,
+                    y=6,
+                    z=7,
+                ),
+            ],
+        )
+
+    @testing.combinations(True, False, argnames="single_element")
+    def test_subclass_return_just_subclass_ids(self, single_element):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        if single_element:
+            data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+        else:
+            data = [
+                {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+                {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+            ]
+
+        ids = s.scalars(insert(B).returning(B.id), data).all()
+        actual_ids = s.scalars(select(B.id).order_by(B.data)).all()
+
+        eq_(ids, actual_ids)
+
+    @testing.combinations(
+        "orm",
+        "bulk",
+        argnames="insert_strategy",
+    )
+    @testing.requires.provisioned_upsert
+    def test_base_class_upsert(self, insert_strategy):
+        """upsert is really tricky.   if you dont have any data updated,
+        then you dont get the rows back and things dont work so well.
+
+        so we need to be careful how much we document this because this is
+        still a thorny use case.
+
+        """
+        A = self.classes.A
+
+        s = fixture_session()
+
+        initial_data = [
+            {"data": "d3", "x": 1, "y": 2, "q": 4},
+            {"data": "d4", "x": 5, "y": 6, "q": 8},
+        ]
+        ids = s.scalars(insert(A).returning(A.id), initial_data).all()
+
+        upsert_data = [
+            {
+                "id": ids[0],
+                "type": "a",
+                "data": "d3",
+                "x": 1,
+                "y": 2,
+            },
+            {
+                "id": 32,
+                "type": "a",
+                "data": "d32",
+                "x": 19,
+                "y": 5,
+            },
+            {
+                "id": ids[1],
+                "type": "a",
+                "data": "d4",
+                "x": 5,
+                "y": 6,
+            },
+            {
+                "id": 28,
+                "type": "a",
+                "data": "d28",
+                "x": 9,
+                "y": 15,
+            },
+        ]
+
+        stmt = provision.upsert(
+            config,
+            A,
+            (A,),
+            lambda inserted: {"data": inserted.data + " upserted"},
+        )
+
+        if insert_strategy == "orm":
+            result = s.scalars(stmt.values(upsert_data))
+        elif insert_strategy == "bulk":
+            result = s.scalars(stmt, upsert_data)
+        else:
+            assert False
+
+        eq_(
+            result.all(),
+            [
+                A(data="d3 upserted", id=ids[0], type="a", x=1, y=2),
+                A(data="d32", id=32, type="a", x=19, y=5),
+                A(data="d4 upserted", id=ids[1], type="a", x=5, y=6),
+                A(data="d28", id=28, type="a", x=9, y=15),
+            ],
+        )
+
+    @testing.combinations(
+        "orm",
+        "bulk",
+        argnames="insert_strategy",
+    )
+    @testing.requires.provisioned_upsert
+    def test_subclass_upsert(self, insert_strategy):
+        """note this is overridden in the joined version to expect failure"""
+
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        idd3 = 1
+        idd4 = 2
+        id32 = 32
+        id28 = 28
+
+        initial_data = [
+            {
+                "id": idd3,
+                "data": "d3",
+                "bd": "bd1",
+                "x": 1,
+                "y": 2,
+                "z": 3,
+                "q": 4,
+            },
+            {
+                "id": idd4,
+                "data": "d4",
+                "bd": "bd2",
+                "x": 5,
+                "y": 6,
+                "z": 7,
+                "q": 8,
+            },
+        ]
+        ids = s.scalars(insert(B).returning(B.id), initial_data).all()
+
+        upsert_data = [
+            {
+                "id": ids[0],
+                "type": "b",
+                "data": "d3",
+                "bd": "bd1_upserted",
+                "x": 1,
+                "y": 2,
+                "z": 33,
+                "q": 44,
+            },
+            {
+                "id": id32,
+                "type": "b",
+                "data": "d32",
+                "bd": "bd 32",
+                "x": 19,
+                "y": 5,
+                "z": 20,
+                "q": 21,
+            },
+            {
+                "id": ids[1],
+                "type": "b",
+                "bd": "bd2_upserted",
+                "data": "d4",
+                "x": 5,
+                "y": 6,
+                "z": 77,
+                "q": 88,
+            },
+            {
+                "id": id28,
+                "type": "b",
+                "data": "d28",
+                "bd": "bd 28",
+                "x": 9,
+                "y": 15,
+                "z": 10,
+                "q": 11,
+            },
+        ]
+
+        stmt = provision.upsert(
+            config,
+            B,
+            (B,),
+            lambda inserted: {
+                "data": inserted.data + " upserted",
+                "bd": inserted.bd + " upserted",
+            },
+        )
+        result = s.scalars(stmt, upsert_data)
+        eq_(
+            result.all(),
+            [
+                B(
+                    bd="bd1_upserted upserted",
+                    data="d3 upserted",
+                    id=ids[0],
+                    q=4,
+                    type="b",
+                    x=1,
+                    y=2,
+                    z=3,
+                ),
+                B(
+                    bd="bd 32",
+                    data="d32",
+                    id=32,
+                    q=21,
+                    type="b",
+                    x=19,
+                    y=5,
+                    z=20,
+                ),
+                B(
+                    bd="bd2_upserted upserted",
+                    data="d4 upserted",
+                    id=ids[1],
+                    q=8,
+                    type="b",
+                    x=5,
+                    y=6,
+                    z=7,
+                ),
+                B(
+                    bd="bd 28",
+                    data="d28",
+                    id=28,
+                    q=11,
+                    type="b",
+                    x=9,
+                    y=15,
+                    z=10,
+                ),
+            ],
+        )
+
+
+class BulkDMLReturningJoinedInhTest(
+    BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+
+    __requires__ = ("insert_returning",)
+    __backend__ = True
+
+    @classmethod
+    def setup_classes(cls):
+        decl_base = cls.DeclarativeBasic
+
+        class A(fixtures.ComparableEntity, decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            type: Mapped[str]
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column("xcol")
+            y: Mapped[Optional[int]]
+
+            __mapper_args__ = {
+                "polymorphic_identity": "a",
+                "polymorphic_on": "type",
+            }
+
+        class B(A):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(
+                ForeignKey("a.id"), primary_key=True
+            )
+            bd: Mapped[str]
+            z: Mapped[Optional[int]] = mapped_column("zcol")
+            q: Mapped[Optional[int]]
+
+            __mapper_args__ = {"polymorphic_identity": "b"}
+
+    @testing.combinations(
+        "orm",
+        "bulk",
+        argnames="insert_strategy",
+    )
+    @testing.combinations(
+        True,
+        False,
+        argnames="single_param",
+    )
+    @testing.requires.provisioned_upsert
+    def test_subclass_upsert(self, insert_strategy, single_param):
+        A, B = self.classes("A", "B")
+
+        s = fixture_session()
+
+        initial_data = [
+            {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+            {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+        ]
+        ids = s.scalars(insert(B).returning(B.id), initial_data).all()
+
+        upsert_data = [
+            {
+                "id": ids[0],
+                "type": "b",
+            },
+            {
+                "id": 32,
+                "type": "b",
+            },
+        ]
+        if single_param:
+            upsert_data = upsert_data[0]
+
+        stmt = provision.upsert(
+            config,
+            B,
+            (B,),
+            lambda inserted: {
+                "bd": inserted.bd + " upserted",
+            },
+        )
+
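+        # joined inheritance maps B onto two tables; a bulk INSERT with an
+        # ON CONFLICT-style "post values" clause can only target a single
+        # table, so this combination raises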
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            r"bulk INSERT with a 'post values' clause \(typically upsert\) "
+            r"not supported for multi-table mapper",
+        ):
+            s.scalars(stmt, upsert_data)
+
+
+class BulkDMLReturningSingleInhTest(
+    BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+    __requires__ = ("insert_returning",)
+    __backend__ = True
+
+    @classmethod
+    def setup_classes(cls):
+        decl_base = cls.DeclarativeBasic
+
+        class A(fixtures.ComparableEntity, decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            type: Mapped[str]
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column("xcol")
+            y: Mapped[Optional[int]]
+
+            __mapper_args__ = {
+                "polymorphic_identity": "a",
+                "polymorphic_on": "type",
+            }
+
+        class B(A):
+            bd: Mapped[str] = mapped_column(nullable=True)
+            z: Mapped[Optional[int]] = mapped_column("zcol")
+            q: Mapped[Optional[int]]
+
+            __mapper_args__ = {"polymorphic_identity": "b"}
+
+
+class BulkDMLReturningConcreteInhTest(
+    BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+    __requires__ = ("insert_returning",)
+    __backend__ = True
+
+    @classmethod
+    def setup_classes(cls):
+        decl_base = cls.DeclarativeBasic
+
+        class A(fixtures.ComparableEntity, decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            type: Mapped[str]
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column("xcol")
+            y: Mapped[Optional[int]]
+
+            __mapper_args__ = {
+                "polymorphic_identity": "a",
+                "polymorphic_on": "type",
+            }
+
+        class B(A):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            type: Mapped[str]
+            data: Mapped[str]
+            x: Mapped[Optional[int]] = mapped_column("xcol")
+            y: Mapped[Optional[int]]
+
+            bd: Mapped[str] = mapped_column(nullable=True)
+            z: Mapped[Optional[int]] = mapped_column("zcol")
+            q: Mapped[Optional[int]]
+
+            __mapper_args__ = {
+                "polymorphic_identity": "b",
+                "concrete": True,
+                "polymorphic_on": "type",
+            }
+
+
+class CTETest(fixtures.DeclarativeMappedTest):
+    __requires__ = ("insert_returning", "ctes_on_dml")
+    __backend__ = True
+
+    @classmethod
+    def setup_classes(cls):
+        decl_base = cls.DeclarativeBasic
+
+        class User(fixtures.ComparableEntity, decl_base):
+            __tablename__ = "users"
+            id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+            username: Mapped[str]
+
+    @testing.combinations(
+        ("cte_aliased", True),
+        ("cte", False),
+        argnames="wrap_cte_in_aliased",
+        id_="ia",
+    )
+    @testing.combinations(
+        ("use_union", True),
+        ("no_union", False),
+        argnames="use_a_union",
+        id_="ia",
+    )
+    @testing.combinations(
+        "from_statement", "aliased", "direct", argnames="fetch_entity_type"
+    )
+    def test_select_from_insert_cte(
+        self, wrap_cte_in_aliased, use_a_union, fetch_entity_type
+    ):
+        """test the use case from #8544; SELECT that selects from a
+        CTE INSERT...RETURNING.
+
+        """
+        User = self.classes.User
+
+        id_ = uuid.uuid4()
+
+        cte = (
+            insert(User)
+            .values(id=id_, username="some user")
+            .returning(User)
+            .cte()
+        )
+        if wrap_cte_in_aliased:
+            cte = aliased(User, cte)
+
+        if use_a_union:
+            stmt = select(User).where(User.id == id_).union(select(cte))
+        else:
+            stmt = select(cte)
+
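+        # whether ORM User entities (as opposed to plain column rows) come
+        # back depends on how the CTE is consumed: from_statement() and an
+        # aliased() subquery both yield entities, while selecting the
+        # un-aliased CTE directly, or selecting a union directly, yields
+        # plain rows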
+        if fetch_entity_type == "from_statement":
+            outer_stmt = select(User).from_statement(stmt)
+            expect_entity = True
+        elif fetch_entity_type == "aliased":
+            outer_stmt = select(aliased(User, stmt.subquery()))
+            expect_entity = True
+        elif fetch_entity_type == "direct":
+            outer_stmt = stmt
+            expect_entity = not use_a_union and wrap_cte_in_aliased
+        else:
+            assert False
+
+        sess = fixture_session()
+        with self.sql_execution_asserter() as asserter:
+
+            if not expect_entity:
+                row = sess.execute(outer_stmt).one()
+                eq_(row, (id_, "some user"))
+            else:
+                new_user = sess.scalars(outer_stmt).one()
+                eq_(new_user, User(id=id_, username="some user"))
+
+        cte_sql = (
+            "(INSERT INTO users (id, username) "
+            "VALUES (:param_1, :param_2) "
+            "RETURNING users.id, users.username)"
+        )
+
+        if fetch_entity_type == "aliased" and not use_a_union:
+            expected = (
+                f"WITH anon_2 AS {cte_sql} "
+                "SELECT anon_1.id, anon_1.username "
+                "FROM (SELECT anon_2.id AS id, anon_2.username AS username "
+                "FROM anon_2) AS anon_1"
+            )
+        elif not use_a_union:
+            expected = (
+                f"WITH anon_1 AS {cte_sql} "
+                "SELECT anon_1.id, anon_1.username FROM anon_1"
+            )
+        elif fetch_entity_type == "aliased":
+            expected = (
+                f"WITH anon_2 AS {cte_sql} SELECT anon_1.id, anon_1.username "
+                "FROM (SELECT users.id AS id, users.username AS username "
+                "FROM users WHERE users.id = :id_1 "
+                "UNION SELECT anon_2.id AS id, anon_2.username AS username "
+                "FROM anon_2) AS anon_1"
+            )
+        else:
+            expected = (
+                f"WITH anon_1 AS {cte_sql} "
+                "SELECT users.id, users.username FROM users "
+                "WHERE users.id = :id_1 "
+                "UNION SELECT anon_1.id, anon_1.username FROM anon_1"
+            )
+
+        asserter.assert_(
+            CompiledSQL(expected, [{"param_1": id_, "param_2": "some user"}])
+        )
similarity index 99%
rename from test/orm/test_evaluator.py
rename to test/orm/dml/test_evaluator.py
index ff40cd20155a68b4295615da63475a5fe5636e5c..4b903b863cc1f747d3bb38f42be621b51357f466 100644 (file)
@@ -324,7 +324,6 @@ class EvaluateTest(fixtures.MappedTest):
         """test #3162"""
 
         User = self.classes.User
-
         with expect_raises_message(
             evaluator.UnevaluatableError,
             r"Custom operator '\^\^' can't be evaluated in "
similarity index 83%
rename from test/orm/test_update_delete.py
rename to test/orm/dml/test_update_delete_where.py
index 1e93f88de608c25970324f486638a514d271837c..836feb6595a8bb818c5cdd2664cb88754994d443 100644 (file)
@@ -1,3 +1,4 @@
+from sqlalchemy import bindparam
 from sqlalchemy import Boolean
 from sqlalchemy import case
 from sqlalchemy import column
@@ -7,6 +8,7 @@ from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import insert
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import lambda_stmt
 from sqlalchemy import MetaData
@@ -17,6 +19,7 @@ from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import update
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
@@ -26,6 +29,7 @@ from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import not_in
@@ -123,6 +127,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
             },
         )
 
+    def test_update_dont_use_col_key(self):
+        User = self.classes.User
+
+        s = fixture_session()
+
+        # make sure objects are present to synchronize
+        _ = s.query(User).all()
+
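+        # synchronization works in terms of ORM attribute names ("age"); a
+        # raw column key such as "age_int" cannot be matched back to the
+        # loaded objects and raises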
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            "Attribute name not found, can't be synchronized back "
+            "to objects: 'age_int'",
+        ):
+            s.execute(update(User).values(age_int=5))
+
+        stmt = update(User).values(age=5)
+        s.execute(stmt)
+        eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5])
+
     @testing.combinations("table", "mapper", "both", argnames="bind_type")
     @testing.combinations(
         "update", "insert", "delete", argnames="statement_type"
@@ -162,7 +185,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
         assert_raises_message(
             exc.ArgumentError,
             "Valid strategies for session synchronization "
-            "are 'evaluate', 'fetch', False",
+            "are 'auto', 'evaluate', 'fetch', False",
             s.query(User).update,
             {},
             synchronize_session="fake",
@@ -351,6 +374,12 @@ class UpdateDeleteTest(fixtures.MappedTest):
     def test_evaluate_dont_refresh_expired_objects(
         self, expire_jane_age, add_filter_criteria
     ):
+        """test #5664.
+
+        The approach is revised in SQLAlchemy 2.0 to no longer
+        pre-emptively unexpire the involved attributes.
+
+        """
         User = self.classes.User
 
         sess = fixture_session()
@@ -379,15 +408,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
         if add_filter_criteria:
             if expire_jane_age:
                 asserter.assert_(
-                    # it has to unexpire jane.name, because jane is not fully
-                    # expired and the criteria needs to look at this particular
-                    # key
-                    CompiledSQL(
-                        "SELECT users.age_int AS users_age_int, "
-                        "users.name AS users_name FROM users "
-                        "WHERE users.id = :pk_1",
-                        [{"pk_1": 4}],
-                    ),
+                    # previously, this would unexpire the attribute and
+                    # emit an additional SELECT up front.  The 2.0 approach
+                    # is that if the object has expired attributes, the
+                    # whole object is expired instead, avoiding that SQL
                     CompiledSQL(
                         "UPDATE users "
                         "SET age_int=(users.age_int + :age_int_1) "
@@ -397,14 +421,10 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 )
             else:
                 asserter.assert_(
-                    # it has to unexpire jane.name, because jane is not fully
-                    # expired and the criteria needs to look at this particular
-                    # key
-                    CompiledSQL(
-                        "SELECT users.name AS users_name FROM users "
-                        "WHERE users.id = :pk_1",
-                        [{"pk_1": 4}],
-                    ),
+                    # previously, this would unexpire the attribute and
+                    # emit an additional SELECT up front.  The 2.0 approach
+                    # is that if the object has expired attributes, the
+                    # whole object is expired instead, avoiding that SQL
                     CompiledSQL(
                         "UPDATE users SET "
                         "age_int=(users.age_int + :age_int_1) "
@@ -443,9 +463,9 @@ class UpdateDeleteTest(fixtures.MappedTest):
             ),
         ]
 
-        if expire_jane_age and not add_filter_criteria:
+        if expire_jane_age:
             to_assert.append(
-                # refresh jane
+                # refresh jane's expired attributes
                 CompiledSQL(
                     "SELECT users.age_int AS users_age_int, "
                     "users.name AS users_name FROM users "
@@ -455,6 +475,75 @@ class UpdateDeleteTest(fixtures.MappedTest):
             )
         asserter.assert_(*to_assert)
 
+    @testing.combinations(True, False, argnames="is_evaluable")
+    def test_auto_synchronize(self, is_evaluable):
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+        if is_evaluable:
+            crit = or_(User.name == "jack", User.name == "jane")
+        else:
+            crit = case((User.name.in_(["jack", "jane"]), True), else_=False)
+
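+        # under the default synchronize_session="auto", criteria that can
+        # be evaluated in Python use the "evaluate" strategy (one UPDATE,
+        # no extra SQL); otherwise it falls back to "fetch", which uses
+        # RETURNING where the dialect supports it and a preceding SELECT
+        # where it does not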
+        with self.sql_execution_asserter() as asserter:
+            sess.execute(update(User).where(crit).values(age=User.age + 10))
+
+        if is_evaluable:
+            asserter.assert_(
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+                    "WHERE users.name = :name_1 OR users.name = :name_2",
+                    [{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}],
+                ),
+            )
+        elif testing.db.dialect.update_returning:
+            asserter.assert_(
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+                    "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
+                    "THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id",
+                    [
+                        {
+                            "age_int_1": 10,
+                            "name_1": ["jack", "jane"],
+                            "param_1": True,
+                            "param_2": False,
+                        }
+                    ],
+                ),
+            )
+        else:
+            asserter.assert_(
+                CompiledSQL(
+                    "SELECT users.id FROM users WHERE CASE WHEN "
+                    "(users.name IN (__[POSTCOMPILE_name_1])) "
+                    "THEN :param_1 ELSE :param_2 END = 1",
+                    [
+                        {
+                            "name_1": ["jack", "jane"],
+                            "param_1": True,
+                            "param_2": False,
+                        }
+                    ],
+                ),
+                CompiledSQL(
+                    "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+                    "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
+                    "THEN :param_1 ELSE :param_2 END = 1",
+                    [
+                        {
+                            "age_int_1": 10,
+                            "name_1": ["jack", "jane"],
+                            "param_1": True,
+                            "param_2": False,
+                        }
+                    ],
+                ),
+            )
+
     def test_fetch_dont_refresh_expired_objects(self):
         User = self.classes.User
 
@@ -518,17 +607,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
             ),
         )
 
-    def test_delete(self):
+    @testing.combinations(False, None, "auto", "evaluate", "fetch")
+    def test_delete(self, synchronize_session):
         User = self.classes.User
 
         sess = fixture_session()
 
         john, jack, jill, jane = sess.query(User).order_by(User.id).all()
-        sess.query(User).filter(
+
+        stmt = delete(User).filter(
             or_(User.name == "john", User.name == "jill")
-        ).delete()
+        )
+        if synchronize_session is not None:
+            stmt = stmt.execution_options(
+                synchronize_session=synchronize_session
+            )
+        sess.execute(stmt)
 
-        assert john not in sess and jill not in sess
+        if synchronize_session not in (False, None):
+            assert john not in sess and jill not in sess
 
         eq_(sess.query(User).order_by(User.id).all(), [jack, jane])
 
@@ -629,6 +726,33 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
 
+    def test_update_multirow_not_supported(self):
+        User = self.classes.User
+
+        sess = fixture_session()
+
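+        # an ORM UPDATE combining WHERE criteria with a list of parameter
+        # sets would imply per-row bulk criteria; this is not implemented
+        # and should be run on a Connection instead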
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            "WHERE clause with bulk ORM UPDATE not supported " "right now.",
+        ):
+            sess.execute(
+                update(User).where(User.id == bindparam("id")),
+                [{"id": 1, "age": 27}, {"id": 2, "age": 37}],
+            )
+
+    def test_delete_bulk_not_supported(self):
+        User = self.classes.User
+
+        sess = fixture_session()
+
+        with expect_raises_message(
+            exc.InvalidRequestError, "Bulk ORM DELETE not supported right now."
+        ):
+            sess.execute(
+                delete(User),
+                [{"id": 1}, {"id": 2}],
+            )
+
     def test_update(self):
         User, users = self.classes.User, self.tables.users
 
@@ -640,6 +764,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
         )
 
         eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+
         eq_(
             sess.query(User.age).order_by(User.id).all(),
             list(zip([25, 37, 29, 27])),
@@ -974,7 +1099,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
             )
 
     @testing.requires.update_returning
-    def test_update_explicit_returning(self):
+    def test_update_evaluate_w_explicit_returning(self):
         User = self.classes.User
 
         sess = fixture_session()
@@ -987,6 +1112,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
                 .filter(User.age > 29)
                 .values({"age": User.age - 10})
                 .returning(User.id)
+                .execution_options(synchronize_session="evaluate")
             )
 
             rows = sess.execute(stmt).all()
@@ -1006,24 +1132,41 @@ class UpdateDeleteTest(fixtures.MappedTest):
         )
 
     @testing.requires.update_returning
-    def test_no_fetch_w_explicit_returning(self):
+    @testing.combinations("update", "delete", argnames="crud_type")
+    def test_fetch_w_explicit_returning(self, crud_type):
         User = self.classes.User
 
         sess = fixture_session()
 
-        stmt = (
-            update(User)
-            .filter(User.age > 29)
-            .values({"age": User.age - 10})
-            .execution_options(synchronize_session="fetch")
-            .returning(User.id)
-        )
-        with expect_raises_message(
-            exc.InvalidRequestError,
-            r"Can't use synchronize_session='fetch' "
-            r"with explicit returning\(\)",
-        ):
-            sess.execute(stmt)
+        if crud_type == "update":
+            stmt = (
+                update(User)
+                .filter(User.age > 29)
+                .values({"age": User.age - 10})
+                .execution_options(synchronize_session="fetch")
+                .returning(User, User.name)
+            )
+            expected = [
+                (User(age=37), "jack"),
+                (User(age=27), "jane"),
+            ]
+        elif crud_type == "delete":
+            stmt = (
+                delete(User)
+                .filter(User.age > 29)
+                .execution_options(synchronize_session="fetch")
+                .returning(User, User.name)
+            )
+            expected = [
+                (User(age=47), "jack"),
+                (User(age=37), "jane"),
+            ]
+        else:
+            assert False
+
+        result = sess.execute(stmt)
+
+        eq_(result.all(), expected)
 
     @testing.combinations(True, False, argnames="implicit_returning")
     def test_delete_fetch_returning(self, implicit_returning):
@@ -1142,7 +1285,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([25, 47, 44, 37])),
         )
 
-    def test_update_changes_resets_dirty(self):
+    @testing.combinations("orm", "bulk")
+    def test_update_changes_resets_dirty(self, update_type):
         User = self.classes.User
 
         sess = fixture_session(autoflush=False)
@@ -1155,9 +1299,30 @@ class UpdateDeleteTest(fixtures.MappedTest):
         # autoflush is false.  therefore our '50' and '37' are getting
         # blown away by this operation.
 
-        sess.query(User).filter(User.age > 29).update(
-            {"age": User.age - 10}, synchronize_session="evaluate"
-        )
+        if update_type == "orm":
+            sess.execute(
+                update(User)
+                .filter(User.age > 29)
+                .values({"age": User.age - 10}),
+                execution_options=dict(synchronize_session="evaluate"),
+            )
+        elif update_type == "bulk":
+
+            data = [
+                {"id": john.id, "age": 25},
+                {"id": jack.id, "age": 37},
+                {"id": jill.id, "age": 29},
+                {"id": jane.id, "age": 27},
+            ]
+
+            sess.execute(
+                update(User),
+                data,
+                execution_options=dict(synchronize_session="evaluate"),
+            )
+
+        else:
+            assert False
 
         for x in (john, jack, jill, jane):
             assert not sess.is_modified(x)
@@ -1171,6 +1336,93 @@ class UpdateDeleteTest(fixtures.MappedTest):
         assert not sess.is_modified(john)
         assert not sess.is_modified(jack)
 
+    @testing.combinations(
+        None, False, "evaluate", "fetch", argnames="synchronize_session"
+    )
+    @testing.combinations(True, False, argnames="homogeneous_keys")
+    def test_bulk_update_synchronize_session(
+        self, synchronize_session, homogeneous_keys
+    ):
+        User = self.classes.User
+
+        sess = fixture_session(expire_on_commit=False)
+
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+        if homogeneous_keys:
+            data = [
+                {"id": john.id, "age": 35},
+                {"id": jack.id, "age": 27},
+                {"id": jill.id, "age": 30},
+            ]
+        else:
+            data = [
+                {"id": john.id, "age": 35},
+                {"id": jack.id, "name": "new jack"},
+                {"id": jill.id, "age": 30, "name": "new jill"},
+            ]
+
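+        # heterogeneous parameter sets are grouped by their key sets and
+        # emitted as separate UPDATE batches, as asserted below; the
+        # "fetch" strategy is not available for bulk ORM updates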
+        with self.sql_execution_asserter() as asserter:
+            if synchronize_session is not None:
+                opts = {"synchronize_session": synchronize_session}
+            else:
+                opts = {}
+
+            if synchronize_session == "fetch":
+                with expect_raises_message(
+                    exc.InvalidRequestError,
+                    "The 'fetch' synchronization strategy is not available "
+                    "for 'bulk' ORM updates",
+                ):
+                    sess.execute(update(User), data, execution_options=opts)
+                return
+            else:
+                sess.execute(update(User), data, execution_options=opts)
+
+        if homogeneous_keys:
+            asserter.assert_(
+                CompiledSQL(
+                    "UPDATE users SET age_int=:age_int "
+                    "WHERE users.id = :users_id",
+                    [
+                        {"age_int": 35, "users_id": 1},
+                        {"age_int": 27, "users_id": 2},
+                        {"age_int": 30, "users_id": 3},
+                    ],
+                )
+            )
+        else:
+            asserter.assert_(
+                CompiledSQL(
+                    "UPDATE users SET age_int=:age_int "
+                    "WHERE users.id = :users_id",
+                    [{"age_int": 35, "users_id": 1}],
+                ),
+                CompiledSQL(
+                    "UPDATE users SET name=:name WHERE users.id = :users_id",
+                    [{"name": "new jack", "users_id": 2}],
+                ),
+                CompiledSQL(
+                    "UPDATE users SET name=:name, age_int=:age_int "
+                    "WHERE users.id = :users_id",
+                    [{"name": "new jill", "age_int": 30, "users_id": 3}],
+                ),
+            )
+
+        if synchronize_session is False:
+            eq_(jill.name, "jill")
+            eq_(jack.name, "jack")
+            eq_(jill.age, 29)
+            eq_(jack.age, 47)
+        else:
+            if not homogeneous_keys:
+                eq_(jill.name, "new jill")
+                eq_(jack.name, "new jack")
+                eq_(jack.age, 47)
+            else:
+                eq_(jack.age, 27)
+            eq_(jill.age, 30)
+
     def test_update_changes_with_autoflush(self):
         User = self.classes.User
 
@@ -1214,7 +1466,8 @@ class UpdateDeleteTest(fixtures.MappedTest):
         )
 
     @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
-    def test_update_returns_rowcount(self):
+    @testing.combinations("auto", "fetch", "evaluate")
+    def test_update_returns_rowcount(self, synchronize_session):
         User = self.classes.User
 
         sess = fixture_session()
@@ -1222,20 +1475,25 @@ class UpdateDeleteTest(fixtures.MappedTest):
         rowcount = (
             sess.query(User)
             .filter(User.age > 29)
-            .update({"age": User.age + 0})
+            .update(
+                {"age": User.age + 0}, synchronize_session=synchronize_session
+            )
         )
         eq_(rowcount, 2)
 
         rowcount = (
             sess.query(User)
             .filter(User.age > 29)
-            .update({"age": User.age - 10})
+            .update(
+                {"age": User.age - 10}, synchronize_session=synchronize_session
+            )
         )
         eq_(rowcount, 2)
 
         # test future
         result = sess.execute(
-            update(User).where(User.age > 19).values({"age": User.age - 10})
+            update(User).where(User.age > 19).values({"age": User.age - 10}),
+            execution_options={"synchronize_session": synchronize_session},
         )
         eq_(result.rowcount, 4)
 
@@ -1327,12 +1585,17 @@ class UpdateDeleteTest(fixtures.MappedTest):
         )
         assert john not in sess
 
-    def test_evaluate_before_update(self):
+    @testing.combinations(True, False)
+    def test_evaluate_before_update(self, full_expiration):
         User = self.classes.User
 
         sess = fixture_session()
         john = sess.query(User).filter_by(name="john").one()
-        sess.expire(john, ["age"])
+
+        if full_expiration:
+            sess.expire(john)
+        else:
+            sess.expire(john, ["age"])
 
         # eval must be before the update.  otherwise
         # we eval john, age has been expired and doesn't
@@ -1356,17 +1619,47 @@ class UpdateDeleteTest(fixtures.MappedTest):
         eq_(john.name, "j2")
         eq_(john.age, 40)
 
-    def test_evaluate_before_delete(self):
+    @testing.combinations(True, False)
+    def test_evaluate_before_delete(self, full_expiration):
         User = self.classes.User
 
         sess = fixture_session()
         john = sess.query(User).filter_by(name="john").one()
-        sess.expire(john, ["age"])
+        jill = sess.query(User).filter_by(name="jill").one()
+        jane = sess.query(User).filter_by(name="jane").one()
 
-        sess.query(User).filter_by(name="john").filter_by(age=25).delete(
+        if full_expiration:
+            sess.expire(jill)
+            sess.expire(john)
+        else:
+            sess.expire(jill, ["age"])
+            sess.expire(john, ["age"])
+
+        sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete(
             synchronize_session="evaluate"
         )
-        assert john not in sess
+
+        # was fully deleted
+        assert jane not in sess
+
+        # jill was expired; the criteria could not be evaluated against
+        # her in memory, so she remains in the session untouched
+        assert jill in sess
+
+        # john was expired as well; although his row was deleted, he is
+        # not evicted from the session up front
+        assert john in sess
+
+        # partially expired row fully expired
+        assert inspect(jill).expired
+
+        # non-deleted row still present
+        eq_(jill.age, 29)
+
+        # partially expired row fully expired
+        assert inspect(john).expired
+
+        # john's row was deleted; refreshing raises ObjectDeletedError
+        with expect_raises(orm_exc.ObjectDeletedError):
+            john.name
 
     def test_fetch_before_delete(self):
         User = self.classes.User
@@ -1378,6 +1671,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
         sess.query(User).filter_by(name="john").filter_by(age=25).delete(
             synchronize_session="fetch"
         )
+
         assert john not in sess
 
     def test_update_unordered_dict(self):
@@ -1495,6 +1789,60 @@ class UpdateDeleteTest(fixtures.MappedTest):
         ]
         eq_(["name", "age_int"], cols)
 
+    @testing.requires.sqlite
+    def test_sharding_extension_returning_mismatch(self, testing_engine):
+        """test one horizontal shard case where the given binds don't match
+        for RETURNING support; we don't support this.
+
+        See test/ext/test_horizontal_shard.py for complete round trip
+        test cases for ORM update/delete
+
+        """
+        e1 = testing_engine("sqlite://")
+        e2 = testing_engine("sqlite://")
+        e1.connect().close()
+        e2.connect().close()
+
+        e1.dialect.update_returning = True
+        e2.dialect.update_returning = False
+
+        engines = [e1, e2]
+
+        # a simulated version of the horizontal sharding extension
+        def execute_and_instances(orm_context):
+            execution_options = dict(orm_context.local_execution_options)
+            partial = []
+            for engine in engines:
+                bind_arguments = dict(orm_context.bind_arguments)
+                bind_arguments["bind"] = engine
+                result_ = orm_context.invoke_statement(
+                    bind_arguments=bind_arguments,
+                    execution_options=execution_options,
+                )
+
+                partial.append(result_)
+            return partial[0].merge(*partial[1:])
+
+        User = self.classes.User
+        session = Session()
+
+        event.listen(
+            session, "do_orm_execute", execute_and_instances, retval=True
+        )
+
+        stmt = (
+            update(User)
+            .filter(User.id == 15)
+            .values(age=123)
+            .execution_options(synchronize_session="fetch")
+        )
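+        # the merged result would mix backends that do and do not support
+        # UPDATE..RETURNING, so synchronize_session="fetch" cannot build a
+        # consistent set of rows to synchronize and raises instead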
+        with expect_raises_message(
+            exc.InvalidRequestError,
+            "For synchronize_session='fetch', can't mix multiple backends "
+            "where some support RETURNING and others don't",
+        ):
+            session.execute(stmt)
+
 
 class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
     @classmethod
@@ -1748,6 +2096,7 @@ class UpdateDeleteFromTest(fixtures.MappedTest):
             "Could not evaluate current criteria in Python.",
             q.update,
             {"samename": "ed"},
+            synchronize_session="evaluate",
         )
 
     @testing.requires.multi_table_update
@@ -1901,7 +2250,7 @@ class ExpressionUpdateTest(fixtures.MappedTest):
         sess.commit()
         eq_(d1.cnt, 0)
 
-        sess.query(Data).update({Data.cnt: Data.cnt + 1})
+        sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate")
         sess.flush()
 
         eq_(d1.cnt, 1)
@@ -2443,7 +2792,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
         )
 
     @testing.requires.update_returning
-    def test_load_from_update(self, connection):
+    @testing.combinations(True, False, argnames="use_from_statement")
+    def test_load_from_update(self, connection, use_from_statement):
         User = self.classes.User
 
         stmt = (
@@ -2453,7 +2803,16 @@ class LoadFromReturningTest(fixtures.MappedTest):
             .returning(User)
         )
 
-        stmt = select(User).from_statement(stmt)
+        if use_from_statement:
+            # this is now a legacy-ish case, because as of 2.0 you can just
+            # use returning() directly to get the objects back.
+            #
+            # when from_statement is used, the UPDATE statement is no
+            # longer interpreted by
+            # BulkUDCompileState.orm_pre_session_exec or
+            # BulkUDCompileState.orm_setup_cursor_result.  The
+            # compilation-level routines still take place, though.
+            stmt = select(User).from_statement(stmt)
 
         with Session(connection) as sess:
             rows = sess.execute(stmt).scalars().all()
@@ -2468,7 +2827,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
         ("multiple", testing.requires.multivalues_inserts),
         argnames="params",
     )
-    def test_load_from_insert(self, connection, params):
+    @testing.combinations(True, False, argnames="use_from_statement")
+    def test_load_from_insert(self, connection, params, use_from_statement):
         User = self.classes.User
 
         if params == "multiple":
@@ -2484,7 +2844,8 @@ class LoadFromReturningTest(fixtures.MappedTest):
 
         stmt = insert(User).values(values).returning(User)
 
-        stmt = select(User).from_statement(stmt)
+        if use_from_statement:
+            stmt = select(User).from_statement(stmt)
 
         with Session(connection) as sess:
             rows = sess.execute(stmt).scalars().all()
@@ -2505,3 +2866,25 @@ class LoadFromReturningTest(fixtures.MappedTest):
                 )
             else:
                 assert False
+
+    @testing.requires.delete_returning
+    @testing.combinations(True, False, argnames="use_from_statement")
+    def test_load_from_delete(self, connection, use_from_statement):
+        User = self.classes.User
+
+        stmt = (
+            delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
+        )
+
+        if use_from_statement:
+            stmt = select(User).from_statement(stmt)
+
+        with Session(connection) as sess:
+            rows = sess.execute(stmt).scalars().all()
+
+            eq_(
+                rows,
+                [User(name="jack", age=47), User(name="jill", age=29)],
+            )
+
+            # TODO: state of above objects should be "deleted"
index 2e3874549eb80e3d004cc4883d6071dace7757bd..5f8cfc1f56779bfa19f9e3365d5c427c84bbc202 100644 (file)
@@ -2012,7 +2012,8 @@ class JoinedNoFKSortingTest(fixtures.MappedTest):
                 and testing.db.dialect.supports_default_metavalue,
                 [
                     CompiledSQL(
-                        "INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}]
+                        "INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id",
+                        [{}, {}, {}, {}],
                     ),
                 ],
                 [
index a6480365d0b4eecdec1d6c4f155fc8717e997eb0..2f392cf6e5ffeeb148905ccdf80103a0492c1f20 100644 (file)
@@ -326,6 +326,7 @@ class BindIntegrationTest(_fixtures.FixtureTest):
         ),
         (
             lambda User: update(User)
+            .execution_options(synchronize_session=False)
             .values(name="not ed")
             .where(User.name == "ed"),
             lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
@@ -392,7 +393,15 @@ class BindIntegrationTest(_fixtures.FixtureTest):
         engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name]
 
         with mock.patch(
-            "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result"
+            "sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result"
+        ), mock.patch(
+            "sqlalchemy.orm.context.ORMCompileState.orm_execute_statement"
+        ), mock.patch(
+            "sqlalchemy.orm.bulk_persistence."
+            "BulkORMInsert.orm_execute_statement"
+        ), mock.patch(
+            "sqlalchemy.orm.bulk_persistence."
+            "BulkUDCompileState.orm_setup_cursor_result"
         ):
             sess.execute(statement)
 
index 3a789aff769069bcb0343d86e886f403ea8ef7eb..efa2ecb45eba2c403e78927aa069bce19f88e2d9 100644 (file)
@@ -1,8 +1,10 @@
 import dataclasses
 import operator
+import random
 
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
+from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import String
@@ -233,7 +235,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             is g.edges[1]
         )
 
-    def test_bulk_update_sql(self):
+    def test_update_crit_sql(self):
         Edge, Point = (self.classes.Edge, self.classes.Point)
 
         sess = self._fixture()
@@ -258,7 +260,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             dialect="default",
         )
 
-    def test_bulk_update_evaluate(self):
+    def test_update_crit_evaluate(self):
         Edge, Point = (self.classes.Edge, self.classes.Point)
 
         sess = self._fixture()
@@ -287,7 +289,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
         eq_(e1.end, Point(17, 8))
 
-    def test_bulk_update_fetch(self):
+    def test_update_crit_fetch(self):
         Edge, Point = (self.classes.Edge, self.classes.Point)
 
         sess = self._fixture()
@@ -305,6 +307,205 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
         eq_(e1.end, Point(17, 8))
 
+    @testing.combinations(
+        "legacy", "statement", "values", "stmt_returning", "values_returning"
+    )
+    def test_bulk_insert(self, type_):
+        Edge, Point = (self.classes.Edge, self.classes.Point)
+        Graph = self.classes.Graph
+
+        sess = self._fixture()
+
+        graph = Graph(id=2)
+        sess.add(graph)
+        sess.flush()
+        graph_id = 2
+
+        data = [
+            {
+                "start": Point(random.randint(1, 50), random.randint(1, 50)),
+                "end": Point(random.randint(1, 50), random.randint(1, 50)),
+                "graph_id": graph_id,
+            }
+            for i in range(25)
+        ]
+        returning = False
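+        # in every variant, the composite "start"/"end" Point values are
+        # expanded into their individual mapped columns (x1/y1, x2/y2),
+        # as verified against the edges table below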
+        if type_ == "statement":
+            sess.execute(insert(Edge), data)
+        elif type_ == "stmt_returning":
+            result = sess.scalars(insert(Edge).returning(Edge), data)
+            returning = True
+        elif type_ == "values":
+            sess.execute(insert(Edge).values(data))
+        elif type_ == "values_returning":
+            result = sess.scalars(insert(Edge).values(data).returning(Edge))
+            returning = True
+        elif type_ == "legacy":
+            sess.bulk_insert_mappings(Edge, data)
+        else:
+            assert False
+
+        if returning:
+            eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data])
+
+        edges = self.tables.edges
+        eq_(
+            sess.execute(
+                select(edges.c["x1", "y1", "x2", "y2"])
+                .where(edges.c.graph_id == graph_id)
+                .order_by(edges.c.id)
+            ).all(),
+            [
+                (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
+                for e in data
+            ],
+        )
+
+    @testing.combinations("legacy", "statement")
+    def test_bulk_insert_heterogeneous(self, type_):
+        Edge, Point = (self.classes.Edge, self.classes.Point)
+        Graph = self.classes.Graph
+
+        sess = self._fixture()
+
+        graph = Graph(id=2)
+        sess.add(graph)
+        sess.flush()
+        graph_id = 2
+
+        d1 = [
+            {
+                "start": Point(random.randint(1, 50), random.randint(1, 50)),
+                "end": Point(random.randint(1, 50), random.randint(1, 50)),
+                "graph_id": graph_id,
+            }
+            for i in range(3)
+        ]
+        d2 = [
+            {
+                "start": Point(random.randint(1, 50), random.randint(1, 50)),
+                "graph_id": graph_id,
+            }
+            for i in range(2)
+        ]
+        d3 = [
+            {
+                "x2": random.randint(1, 50),
+                "y2": random.randint(1, 50),
+                "graph_id": graph_id,
+            }
+            for i in range(2)
+        ]
+        data = d1 + d2 + d3
+        random.shuffle(data)
+
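+        # the parameter dictionaries are heterogeneous: some provide both
+        # composites, some only a start point, and some only the raw
+        # x2/y2 columns; assert_data normalizes them into the expected
+        # column-level values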
+        assert_data = [
+            {
+                "start": d["start"] if "start" in d else None,
+                "end": d["end"]
+                if "end" in d
+                else Point(d["x2"], d["y2"])
+                if "x2" in d
+                else None,
+                "graph_id": d["graph_id"],
+            }
+            for d in data
+        ]
+
+        if type_ == "statement":
+            sess.execute(insert(Edge), data)
+        elif type_ == "legacy":
+            sess.bulk_insert_mappings(Edge, data)
+        else:
+            assert False
+
+        edges = self.tables.edges
+        eq_(
+            sess.execute(
+                select(edges.c["x1", "y1", "x2", "y2"])
+                .where(edges.c.graph_id == graph_id)
+                .order_by(edges.c.id)
+            ).all(),
+            [
+                (
+                    e["start"].x if e["start"] else None,
+                    e["start"].y if e["start"] else None,
+                    e["end"].x if e["end"] else None,
+                    e["end"].y if e["end"] else None,
+                )
+                for e in assert_data
+            ],
+        )
+
+    @testing.combinations("legacy", "statement")
+    def test_bulk_update(self, type_):
+        Edge, Point = (self.classes.Edge, self.classes.Point)
+        Graph = self.classes.Graph
+
+        sess = self._fixture()
+
+        graph = Graph(id=2)
+        sess.add(graph)
+        sess.flush()
+        graph_id = 2
+
+        data = [
+            {
+                "start": Point(random.randint(1, 50), random.randint(1, 50)),
+                "end": Point(random.randint(1, 50), random.randint(1, 50)),
+                "graph_id": graph_id,
+            }
+            for i in range(25)
+        ]
+        sess.execute(insert(Edge), data)
+
+        inserted_data = [
+            dict(row._mapping)
+            for row in sess.execute(
+                select(Edge.id, Edge.start, Edge.end, Edge.graph_id)
+                .where(Edge.graph_id == graph_id)
+                .order_by(Edge.id)
+            )
+        ]
+
+        to_update = []
+        updated_pks = {}
+        for rec in random.choices(inserted_data, k=7):
+            rec_copy = dict(rec)
+            updated_pks[rec_copy["id"]] = rec_copy
+            rec_copy["start"] = Point(
+                random.randint(1, 50), random.randint(1, 50)
+            )
+            rec_copy["end"] = Point(
+                random.randint(1, 50), random.randint(1, 50)
+            )
+            to_update.append(rec_copy)
+
+        expected_dataset = [
+            updated_pks[row["id"]] if row["id"] in updated_pks else row
+            for row in inserted_data
+        ]
+
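+        # both paths emit a bulk UPDATE keyed on the primary key present
+        # in each dictionary, with the composite start/end values expanded
+        # into their individual columns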
+        if type_ == "statement":
+            sess.execute(update(Edge), to_update)
+        elif type_ == "legacy":
+            sess.bulk_update_mappings(Edge, to_update)
+        else:
+            assert False
+
+        edges = self.tables.edges
+        eq_(
+            sess.execute(
+                select(edges.c["x1", "y1", "x2", "y2"])
+                .where(edges.c.graph_id == graph_id)
+                .order_by(edges.c.id)
+            ).all(),
+            [
+                (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
+                for e in expected_dataset
+            ],
+        )
+
     def test_get_history(self):
         Edge = self.classes.Edge
         Point = self.classes.Point
index 15155293faa80f2b243d942bd107db5710bede44..4d4a7ff64331aff336321fe88cfcd9ea08532280 100644 (file)
@@ -1125,7 +1125,7 @@ class OneToManyManyToOneTest(fixtures.MappedTest):
                 [
                     CompiledSQL(
                         "INSERT INTO ball (person_id, data) "
-                        "VALUES (:person_id, :data)",
+                        "VALUES (:person_id, :data) RETURNING ball.id",
                         [
                             {"person_id": None, "data": "some data"},
                             {"person_id": None, "data": "some data"},
index 7860f5eb1d0e3f4d2ffccf8ca6154b6d649e73f1..e738689b8901248614c295d2fc59394a89d432b6 100644 (file)
@@ -383,20 +383,24 @@ class ComputedDefaultsOnUpdateTest(fixtures.MappedTest):
                 CompiledSQL(
                     "UPDATE test SET foo=:foo WHERE test.id = :test_id",
                     [{"foo": 5, "test_id": 1}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "UPDATE test SET foo=:foo WHERE test.id = :test_id",
                     [{"foo": 6, "test_id": 2}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT test.bar AS test_bar FROM test "
                     "WHERE test.id = :pk_1",
                     [{"pk_1": 1}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT test.bar AS test_bar FROM test "
                     "WHERE test.id = :pk_1",
                     [{"pk_1": 2}],
+                    enable_returning=False,
                 ),
             )
         else:
index 24870e20f1219b164d8a9765f738ac0d3e31f015..75955afb5a5e4c0b4e6b2370ea0b25bc36fd9700 100644 (file)
@@ -661,8 +661,17 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
 
         canary = self._flag_fixture(sess)
 
-        sess.execute(delete(User).filter_by(id=18))
-        sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
+        sess.execute(
+            delete(User)
+            .filter_by(id=18)
+            .execution_options(synchronize_session="evaluate")
+        )
+        sess.execute(
+            update(User)
+            .filter_by(id=18)
+            .values(name="eighteen")
+            .execution_options(synchronize_session="evaluate")
+        )
 
         eq_(
             canary.mock_calls,
index b9499871606187fe3c283a6644b5629d0ce06ae2..fc452dc9c1a411deb8047bc0ed66f0bdaedea813 100644 (file)
@@ -2868,12 +2868,14 @@ class SaveTest2(_fixtures.FixtureTest):
                 testing.db.dialect.insert_executemany_returning,
                 [
                     CompiledSQL(
-                        "INSERT INTO users (name) VALUES (:name)",
+                        "INSERT INTO users (name) VALUES (:name) "
+                        "RETURNING users.id",
                         [{"name": "u1"}, {"name": "u2"}],
                     ),
                     CompiledSQL(
                         "INSERT INTO addresses (user_id, email_address) "
-                        "VALUES (:user_id, :email_address)",
+                        "VALUES (:user_id, :email_address) "
+                        "RETURNING addresses.id",
                         [
                             {"user_id": 1, "email_address": "a1"},
                             {"user_id": 2, "email_address": "a2"},
index dd3b88915113e10a7c416060830ed14972ce49cd..855b44e810ff41b8baf7f676c13a5d2a39384597 100644 (file)
@@ -98,7 +98,8 @@ class RudimentaryFlushTest(UOWTest):
                 [
                     CompiledSQL(
                         "INSERT INTO addresses (user_id, email_address) "
-                        "VALUES (:user_id, :email_address)",
+                        "VALUES (:user_id, :email_address) "
+                        "RETURNING addresses.id",
                         lambda ctx: [
                             {"email_address": "a1", "user_id": u1.id},
                             {"email_address": "a2", "user_id": u1.id},
@@ -220,7 +221,8 @@ class RudimentaryFlushTest(UOWTest):
                 [
                     CompiledSQL(
                         "INSERT INTO addresses (user_id, email_address) "
-                        "VALUES (:user_id, :email_address)",
+                        "VALUES (:user_id, :email_address) "
+                        "RETURNING addresses.id",
                         lambda ctx: [
                             {"email_address": "a1", "user_id": u1.id},
                             {"email_address": "a2", "user_id": u1.id},
@@ -889,7 +891,7 @@ class SingleCycleTest(UOWTest):
                 [
                     CompiledSQL(
                         "INSERT INTO nodes (parent_id, data) VALUES "
-                        "(:parent_id, :data)",
+                        "(:parent_id, :data) RETURNING nodes.id",
                         lambda ctx: [
                             {"parent_id": n1.id, "data": "n2"},
                             {"parent_id": n1.id, "data": "n3"},
@@ -1003,7 +1005,7 @@ class SingleCycleTest(UOWTest):
                 [
                     CompiledSQL(
                         "INSERT INTO nodes (parent_id, data) VALUES "
-                        "(:parent_id, :data)",
+                        "(:parent_id, :data) RETURNING nodes.id",
                         lambda ctx: [
                             {"parent_id": n1.id, "data": "n2"},
                             {"parent_id": n1.id, "data": "n3"},
@@ -1165,7 +1167,7 @@ class SingleCycleTest(UOWTest):
                 [
                     CompiledSQL(
                         "INSERT INTO nodes (parent_id, data) VALUES "
-                        "(:parent_id, :data)",
+                        "(:parent_id, :data) RETURNING nodes.id",
                         lambda ctx: [
                             {"parent_id": n1.id, "data": "n11"},
                             {"parent_id": n1.id, "data": "n12"},
@@ -1196,7 +1198,7 @@ class SingleCycleTest(UOWTest):
                 [
                     CompiledSQL(
                         "INSERT INTO nodes (parent_id, data) VALUES "
-                        "(:parent_id, :data)",
+                        "(:parent_id, :data) RETURNING nodes.id",
                         lambda ctx: [
                             {"parent_id": n12.id, "data": "n121"},
                             {"parent_id": n12.id, "data": "n122"},
@@ -2099,7 +2101,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
                 testing.db.dialect.insert_executemany_returning,
                 [
                     CompiledSQL(
-                        "INSERT INTO t (data) VALUES (:data)",
+                        "INSERT INTO t (data) VALUES (:data) RETURNING t.id",
                         [{"data": "t1"}, {"data": "t2"}],
                     ),
                 ],
@@ -2472,20 +2474,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
                 CompiledSQL(
                     "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
                     [{"id": 1}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
                     [{"id": 2}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT test.foo AS test_foo FROM test "
                     "WHERE test.id = :pk_1",
                     [{"pk_1": 1}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT test.foo AS test_foo FROM test "
                     "WHERE test.id = :pk_1",
                     [{"pk_1": 2}],
+                    enable_returning=False,
                 ),
             )
 
@@ -2678,20 +2684,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
                     CompiledSQL(
                         "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                         [{"foo": 5, "test2_id": 1}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "UPDATE test2 SET foo=:foo, bar=:bar "
                         "WHERE test2.id = :test2_id",
                         [{"foo": 6, "bar": 10, "test2_id": 2}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                         [{"foo": 7, "test2_id": 3}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "UPDATE test2 SET foo=:foo, bar=:bar "
                         "WHERE test2.id = :test2_id",
                         [{"foo": 8, "bar": 12, "test2_id": 4}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "SELECT test2.bar AS test2_bar FROM test2 "
@@ -2772,31 +2782,37 @@ class EagerDefaultsTest(fixtures.MappedTest):
                         "UPDATE test4 SET foo=:foo, bar=5 + 3 "
                         "WHERE test4.id = :test4_id",
                         [{"foo": 5, "test4_id": 1}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "UPDATE test4 SET foo=:foo, bar=:bar "
                         "WHERE test4.id = :test4_id",
                         [{"foo": 6, "bar": 10, "test4_id": 2}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "UPDATE test4 SET foo=:foo, bar=5 + 3 "
                         "WHERE test4.id = :test4_id",
                         [{"foo": 7, "test4_id": 3}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "UPDATE test4 SET foo=:foo, bar=:bar "
                         "WHERE test4.id = :test4_id",
                         [{"foo": 8, "bar": 12, "test4_id": 4}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "SELECT test4.bar AS test4_bar FROM test4 "
                         "WHERE test4.id = :pk_1",
                         [{"pk_1": 1}],
+                        enable_returning=False,
                     ),
                     CompiledSQL(
                         "SELECT test4.bar AS test4_bar FROM test4 "
                         "WHERE test4.id = :pk_1",
                         [{"pk_1": 3}],
+                        enable_returning=False,
                     ),
                 ],
             ),
@@ -2871,20 +2887,24 @@ class EagerDefaultsTest(fixtures.MappedTest):
                     "UPDATE test2 SET foo=:foo, bar=1 + 1 "
                     "WHERE test2.id = :test2_id",
                     [{"foo": 5, "test2_id": 1}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "UPDATE test2 SET foo=:foo, bar=:bar "
                     "WHERE test2.id = :test2_id",
                     [{"foo": 6, "bar": 10, "test2_id": 2}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
                     [{"foo": 7, "test2_id": 3}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "UPDATE test2 SET foo=:foo, bar=5 + 7 "
                     "WHERE test2.id = :test2_id",
                     [{"foo": 8, "test2_id": 4}],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT test2.bar AS test2_bar FROM test2 "
index abd5833bee77868d45679c8f2a136ddfef054e32..84e5a83b07e3f7888cbc9180606ed8c35c5702f1 100644
@@ -1424,12 +1424,10 @@ class ServerVersioningTest(fixtures.MappedTest):
         sess.add(f1)
 
         statements = [
-            # note that the assertsql tests the rule against
-            # "default" - on a "returning" backend, the statement
-            # includes "RETURNING"
             CompiledSQL(
                 "INSERT INTO version_table (version_id, value) "
-                "VALUES (1, :value)",
+                "VALUES (1, :value) "
+                "RETURNING version_table.id, version_table.version_id",
                 lambda ctx: [{"value": "f1"}],
             )
         ]
@@ -1493,6 +1491,7 @@ class ServerVersioningTest(fixtures.MappedTest):
                             "value": "f2",
                         }
                     ],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT version_table.version_id "
@@ -1618,6 +1617,7 @@ class ServerVersioningTest(fixtures.MappedTest):
                             "value": "f1a",
                         }
                     ],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "UPDATE version_table SET version_id=2, value=:value "
@@ -1630,6 +1630,7 @@ class ServerVersioningTest(fixtures.MappedTest):
                             "value": "f2a",
                         }
                     ],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "UPDATE version_table SET version_id=2, value=:value "
@@ -1642,6 +1643,7 @@ class ServerVersioningTest(fixtures.MappedTest):
                             "value": "f3a",
                         }
                     ],
+                    enable_returning=False,
                 ),
                 CompiledSQL(
                     "SELECT version_table.version_id "
index 42cf31bf54d9c1ee85e04ffff784e21479dd7495..4f776e30033e4658d233f438777fbbce7552e9c3 100644
@@ -100,10 +100,55 @@ class CursorResultTest(fixtures.TablesTest):
         Table(
             "test",
             metadata,
-            Column("x", Integer, primary_key=True),
+            Column(
+                "x", Integer, primary_key=True, test_needs_autoincrement=False
+            ),
             Column("y", String(50)),
         )
 
+    @testing.requires.insert_returning
+    def test_splice_horizontally(self, connection):
+        users = self.tables.users
+        addresses = self.tables.addresses
+
+        r1 = connection.execute(
+            users.insert().returning(users.c.user_name, users.c.user_id),
+            [
+                dict(user_id=1, user_name="john"),
+                dict(user_id=2, user_name="jack"),
+            ],
+        )
+
+        r2 = connection.execute(
+            addresses.insert().returning(
+                addresses.c.address_id,
+                addresses.c.address,
+                addresses.c.user_id,
+            ),
+            [
+                dict(address_id=1, user_id=1, address="foo@bar.com"),
+                dict(address_id=2, user_id=2, address="bar@bat.com"),
+            ],
+        )
+
+        rows = r1.splice_horizontally(r2).all()
+        eq_(
+            rows,
+            [
+                ("john", 1, 1, "foo@bar.com", 1),
+                ("jack", 2, 2, "bar@bat.com", 2),
+            ],
+        )
+
+        eq_(rows[0]._mapping[users.c.user_id], 1)
+        eq_(rows[0]._mapping[addresses.c.user_id], 1)
+        eq_(rows[1].address, "bar@bat.com")
+
+        with expect_raises_message(
+            exc.InvalidRequestError, "Ambiguous column name 'user_id'"
+        ):
+            rows[0].user_id
+
     def test_keys_no_rows(self, connection):
 
         for i in range(2):
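test_splice_horizontally covers the new CursorResult.splice_horizontally() used by the ORM bulk path: two RETURNING results are stitched together column-wise, row by row, and duplicate column names remain addressable through their Column objects. A standalone sketch with hypothetical tables rather than the users/addresses fixtures; it assumes a backend where INSERT ... RETURNING is available (e.g. SQLite 3.35+):

    # standalone sketch, hypothetical tables; not the users/addresses fixtures
    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, create_engine
    )

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("user_id", Integer, primary_key=True),
        Column("user_name", String(30)),
    )
    addresses = Table(
        "addresses",
        metadata,
        Column("address_id", Integer, primary_key=True),
        Column("user_id", Integer),
        Column("address", String(50)),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    with engine.begin() as conn:
        r1 = conn.execute(
            users.insert().returning(users.c.user_name, users.c.user_id),
            {"user_id": 1, "user_name": "john"},
        )
        r2 = conn.execute(
            addresses.insert().returning(
                addresses.c.address_id, addresses.c.address, addresses.c.user_id
            ),
            {"address_id": 1, "user_id": 1, "address": "foo@bar.com"},
        )

        # rows are combined positionally: ("john", 1) + (1, "foo@bar.com", 1)
        rows = r1.splice_horizontally(r2).all()

        # "user_id" appears in both results; the string key is ambiguous but
        # the Column objects still resolve
        assert rows[0]._mapping[users.c.user_id] == 1
        assert rows[0]._mapping[addresses.c.user_id] == 1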
index f8cc325170552c1c981732b286a99ee83050b28d..c26f825c27c13727cd9499b31f9f61ed023ff170 100644
@@ -23,6 +23,7 @@ from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import provision
 from sqlalchemy.testing.schema import Column
@@ -76,6 +77,7 @@ class ReturnCombinationTests(fixtures.TestBase, AssertsCompiledSQL):
         stmt = stmt.returning(t.c.x)
 
         stmt = stmt.return_defaults()
+
         assert_raises_message(
             sa_exc.CompileError,
             r"Can't compile statement that includes returning\(\) "
@@ -330,6 +332,7 @@ class InsertReturningTest(fixtures.TablesTest, AssertsExecutionResults):
         table = self.tables.returning_tbl
 
         exprs = testing.resolve_lambda(testcase, table=table)
+
         result = connection.execute(
             table.insert().returning(*exprs),
             {"persons": 5, "full": False, "strval": "str1"},
@@ -679,6 +682,30 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
             Column("upddef", Integer, onupdate=IncDefault()),
         )
 
+        Table(
+            "table_no_addtl_defaults",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", String(50)),
+        )
+
+        class MyType(TypeDecorator):
+            impl = String(50)
+
+            def process_result_value(self, value, dialect):
+                return f"PROCESSED! {value}"
+
+        Table(
+            "table_datatype_has_result_proc",
+            metadata,
+            Column(
+                "id", Integer, primary_key=True, test_needs_autoincrement=True
+            ),
+            Column("data", MyType()),
+        )
+
     def test_chained_insert_pk(self, connection):
         t1 = self.tables.t1
         result = connection.execute(
@@ -758,6 +785,38 @@ class InsertReturnDefaultsTest(fixtures.TablesTest):
         )
         eq_(result.inserted_primary_key, (1,))
 
+    def test_insert_w_defaults_supplemental_cols(self, connection):
+        t1 = self.tables.t1
+        result = connection.execute(
+            t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
+            {"data": "d1"},
+        )
+        eq_(result.all(), [(1, 0, None)])
+
+    def test_insert_w_no_defaults_supplemental_cols(self, connection):
+        t1 = self.tables.table_no_addtl_defaults
+        result = connection.execute(
+            t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
+            {"data": "d1"},
+        )
+        eq_(result.all(), [(1,)])
+
+    def test_insert_w_defaults_supplemental_processor_cols(self, connection):
+        """test that the cursor._rewind() used by supplemental RETURNING
+        clears out result-row processors as we will have already processed
+        the rows.
+
+        """
+
+        t1 = self.tables.table_datatype_has_result_proc
+        result = connection.execute(
+            t1.insert().return_defaults(
+                supplemental_cols=[t1.c.id, t1.c.data]
+            ),
+            {"data": "d1"},
+        )
+        eq_(result.all(), [(1, "PROCESSED! d1")])
+
 
 class UpdatedReturnDefaultsTest(fixtures.TablesTest):
     __requires__ = ("update_returning",)
@@ -792,6 +851,7 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
 
         t1 = self.tables.t1
         connection.execute(t1.insert().values(upddef=1))
+
         result = connection.execute(
             t1.update().values(upddef=2).return_defaults(t1.c.data)
         )
@@ -800,6 +860,72 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
             [None],
         )
 
+    def test_update_values_col_is_excluded(self, connection):
+        """columns that are in values() are not returned"""
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+
+        result = connection.execute(
+            t1.update().values(data="x", upddef=2).return_defaults(t1.c.data)
+        )
+        is_(result.returned_defaults, None)
+
+        result = connection.execute(
+            t1.update()
+            .values(data="x", upddef=2)
+            .return_defaults(t1.c.data, t1.c.id)
+        )
+        eq_(result.returned_defaults, (1,))
+
+    def test_update_supplemental_cols(self, connection):
+        """with supplemental_cols, we can get back arbitrary cols."""
+
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
+            t1.update()
+            .values(data="x", insdef=3)
+            .return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef])
+        )
+
+        row = result.returned_defaults
+
+        # row has all the cols in it
+        eq_(row, ("x", 3, 1))
+        eq_(row._mapping[t1.c.upddef], 1)
+        eq_(row._mapping[t1.c.insdef], 3)
+
+        # result is rewound
+        # but has both return_defaults + supplemental_cols
+        eq_(result.all(), [("x", 3, 1)])
+
+    def test_update_expl_return_defaults_plus_supplemental_cols(
+        self, connection
+    ):
+        """with supplemental_cols, we can get back arbitrary cols."""
+
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
+            t1.update()
+            .values(data="x", insdef=3)
+            .return_defaults(
+                t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
+            )
+        )
+
+        row = result.returned_defaults
+
+        # row has all the cols in it
+        eq_(row, (1, "x", 3))
+        eq_(row._mapping[t1.c.id], 1)
+        eq_(row._mapping[t1.c.insdef], 3)
+        assert t1.c.upddef not in row._mapping
+
+        # result is rewound
+        # but has both return_defaults + supplemental_cols
+        eq_(result.all(), [(1, "x", 3)])
+
     def test_update_sql_expr(self, connection):
         from sqlalchemy import literal
 
@@ -833,6 +959,75 @@ class UpdatedReturnDefaultsTest(fixtures.TablesTest):
         eq_(dict(result.returned_defaults._mapping), {"upddef": 1})
 
 
+class DeleteReturnDefaultsTest(fixtures.TablesTest):
+    __requires__ = ("delete_returning",)
+    run_define_tables = "each"
+    __backend__ = True
+
+    define_tables = InsertReturnDefaultsTest.define_tables
+
+    def test_delete(self, connection):
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(t1.delete().return_defaults(t1.c.upddef))
+        eq_(
+            [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
+        )
+
+    def test_delete_empty_return_defaults(self, connection):
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=5))
+        result = connection.execute(t1.delete().return_defaults())
+
+        # there's no "delete" default, so we get None.  we have to
+        # ask for them in all cases
+        eq_(result.returned_defaults, None)
+
+    def test_delete_non_default(self, connection):
+        """test that a column not marked at all as a
+        default works with this feature."""
+
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(t1.delete().return_defaults(t1.c.data))
+        eq_(
+            [result.returned_defaults._mapping[k] for k in (t1.c.data,)],
+            [None],
+        )
+
+    def test_delete_non_default_plus_default(self, connection):
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
+            t1.delete().return_defaults(t1.c.data, t1.c.upddef)
+        )
+        eq_(
+            dict(result.returned_defaults._mapping),
+            {"data": None, "upddef": 1},
+        )
+
+    def test_delete_supplemental_cols(self, connection):
+        """with supplemental_cols, we can get back arbitrary cols."""
+
+        t1 = self.tables.t1
+        connection.execute(t1.insert().values(upddef=1))
+        result = connection.execute(
+            t1.delete().return_defaults(
+                t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
+            )
+        )
+
+        row = result.returned_defaults
+
+        # row has all the cols in it
+        eq_(row, (1, None, 0))
+        eq_(row._mapping[t1.c.insdef], 0)
+
+        # result is rewound
+        # but has both return_defaults + supplemental_cols
+        eq_(result.all(), [(1, None, 0)])
+
+
 class InsertManyReturnDefaultsTest(fixtures.TablesTest):
     __requires__ = ("insert_executemany_returning",)
     run_define_tables = "each"
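The DeleteReturnDefaultsTest class added above extends return_defaults() to DELETE: there is no such thing as a delete-time default, so a bare return_defaults() yields None, and the values of the deleted row have to be requested explicitly or via supplemental_cols. Condensed from those tests, same t1 fixture and a delete-RETURNING backend assumed:

    # condensed from the DeleteReturnDefaultsTest cases above
    connection.execute(t1.insert().values(upddef=1))

    result = connection.execute(
        t1.delete().return_defaults(
            t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
        )
    )
    row = result.returned_defaults   # (1, None, 0): values of the deleted row
    result.all()                     # rewound, same combined row: [(1, None, 0)]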
index 64ff2e421e371a413110205ff05fdd91888e8778..5ef927b157ee76ddf7283577670bafc160ab079e 100644
@@ -44,6 +44,7 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors
+from sqlalchemy.sql.dml import Insert
 from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
@@ -3029,6 +3030,26 @@ class AnnotationsTest(fixtures.TestBase):
         eq_(whereclause.left._annotations, {"foo": "bar"})
         eq_(whereclause.right._annotations, {"foo": "bar"})
 
+    @testing.combinations(True, False, None)
+    def test_setup_inherit_cache(self, inherit_cache_value):
+        if inherit_cache_value is None:
+
+            class MyInsertThing(Insert):
+                pass
+
+        else:
+
+            class MyInsertThing(Insert):
+                inherit_cache = inherit_cache_value
+
+        t = table("t", column("x"))
+        anno = MyInsertThing(t)._annotate({"foo": "bar"})
+
+        if inherit_cache_value is not None:
+            is_(type(anno).__dict__["inherit_cache"], inherit_cache_value)
+        else:
+            assert "inherit_cache" not in type(anno).__dict__
+
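test_setup_inherit_cache checks that annotating a DML subclass neither invents nor drops an inherit_cache setting: an explicitly assigned value is carried onto the annotated class, and when nothing was assigned the attribute stays absent, so the usual caching warning for classes that never set inherit_cache is unaffected. The opt-in itself looks like:

    # minimal sketch; MyInsertThing is the hypothetical subclass from the test
    from sqlalchemy.sql.dml import Insert


    class MyInsertThing(Insert):
        # reuse Insert's cache key scheme for this subclass
        inherit_cache = True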
     def test_proxy_set_iteration_includes_annotated(self):
         from sqlalchemy.schema import Column