With both the 1.x and 2.0 form of ORM-enabled updates and deletes, the following
values for ``synchronize_session`` are supported:
+* ``'auto'`` - this is the default. The ``'fetch'`` strategy will be used on
+ backends that support RETURNING, which includes all SQLAlchemy-native drivers
+ except for MySQL. If RETURNING is not supported, the ``'evaluate'``
+ strategy will be used instead.
+
+ .. versionchanged:: 2.0 Added the ``'auto'`` synchronization strategy. As
+ most backends now support RETURNING, selecting ``'fetch'`` for these
+ backends specifically is the more efficient and error-free default for
+ these backends. The MySQL backend as well as third party backends without
+ RETURNING support will continue to use ``'evaluate'`` by default.
+
* ``False`` - don't synchronize the session. This option is the most
efficient and is reliable once the session is expired, which
typically occurs after a commit(), or explicitly using
The SQL Server ``fast_executemany`` parameter may be used at the same time
as ``insertmanyvalues`` is enabled; however, the parameter will not be used
in as many cases as INSERT statements that are invoked using Core
- :class:`.Insert` constructs as well as all ORM use no longer use the
+ :class:`_dml.Insert` constructs as well as all ORM use no longer use the
``.executemany()`` DBAPI cursor method.
The PyODBC driver includes support for a "fast executemany" mode of execution
stmt = insert(table)
+ table_pk = inspect(table).selectable
+
if set_lambda:
stmt = stmt.on_conflict_do_update(
- index_elements=table.primary_key, set_=set_lambda(stmt.excluded)
+ index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
)
else:
stmt = stmt.on_conflict_do_nothing()
return target_text
- def visit_insert(self, insert_stmt, **kw):
- if insert_stmt._post_values_clause is not None:
- kw["disable_implicit_returning"] = True
- return super().visit_insert(insert_stmt, **kw)
-
def visit_on_conflict_do_nothing(self, on_conflict, **kw):
target_text = self._on_conflict_target(on_conflict, **kw)
from typing import List
from typing import NoReturn
from typing import Optional
+from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
+from .result import IteratorResult
from .result import MergedResult
from .result import Result
from .result import ResultMetaData
from .interfaces import ExecutionContext
from .result import _KeyIndexType
from .result import _KeyMapRecType
+ from .result import _KeyMapType
from .result import _KeyType
from .result import _ProcessorsType
+ from .result import _TupleGetterType
from ..sql.type_api import _ResultProcessorType
_T = TypeVar("_T", bound=Any)
+
# metadata entry tuple indexes.
# using raw tuple is faster than namedtuple.
-MD_INDEX: Literal[0] = 0 # integer index in cursor.description
-MD_RESULT_MAP_INDEX: Literal[
- 1
-] = 1 # integer index in compiled._result_columns
-MD_OBJECTS: Literal[
- 2
-] = 2 # other string keys and ColumnElement obj that can match
-MD_LOOKUP_KEY: Literal[
- 3
-] = 3 # string key we usually expect for key-based lookup
-MD_RENDERED_NAME: Literal[4] = 4 # name that is usually in cursor.description
-MD_PROCESSOR: Literal[5] = 5 # callable to process a result value into a row
-MD_UNTRANSLATED: Literal[6] = 6 # raw name from cursor.description
+# these match up to the positions in
+# _CursorKeyMapRecType
+MD_INDEX: Literal[0] = 0
+"""integer index in cursor.description
+
+"""
+
+MD_RESULT_MAP_INDEX: Literal[1] = 1
+"""integer index in compiled._result_columns"""
+
+MD_OBJECTS: Literal[2] = 2
+"""other string keys and ColumnElement obj that can match.
+
+This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects
+
+"""
+
+MD_LOOKUP_KEY: Literal[3] = 3
+"""string key we usually expect for key-based lookup
+
+this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name
+"""
+
+
+MD_RENDERED_NAME: Literal[4] = 4
+"""name that is usually in cursor.description
+
+this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname
+"""
+
+
+MD_PROCESSOR: Literal[5] = 5
+"""callable to process a result value into a row"""
+
+MD_UNTRANSLATED: Literal[6] = 6
+"""raw name from cursor.description"""
_CursorKeyMapRecType = Tuple[
- int, int, List[Any], str, str, Optional["_ResultProcessorType"], str
+ Optional[int], # MD_INDEX, None means the record is ambiguously named
+ int, # MD_RESULT_MAP_INDEX
+ List[Any], # MD_OBJECTS
+ str, # MD_LOOKUP_KEY
+ str, # MD_RENDERED_NAME
+ Optional["_ResultProcessorType"], # MD_PROCESSOR
+ Optional[str], # MD_UNTRANSLATED
]
_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType]
+# same as _CursorKeyMapRecType except the MD_INDEX value is definitely
+# not None
+_NonAmbigCursorKeyMapRecType = Tuple[
+ int,
+ int,
+ List[Any],
+ str,
+ str,
+ Optional["_ResultProcessorType"],
+ str,
+]
+
class CursorResultMetaData(ResultMetaData):
"""Result metadata for DBAPI cursors."""
extra=[self._keymap[key][MD_OBJECTS] for key in self._keys],
)
- def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
- recs = cast(
- "List[_CursorKeyMapRecType]", list(self._metadata_for_keys(keys))
+ def _make_new_metadata(
+ self,
+ *,
+ unpickled: bool,
+ processors: _ProcessorsType,
+ keys: Sequence[str],
+ keymap: _KeyMapType,
+ tuplefilter: Optional[_TupleGetterType],
+ translated_indexes: Optional[List[int]],
+ safe_for_cache: bool,
+ keymap_by_result_column_idx: Any,
+ ) -> CursorResultMetaData:
+ new_obj = self.__class__.__new__(self.__class__)
+ new_obj._unpickled = unpickled
+ new_obj._processors = processors
+ new_obj._keys = keys
+ new_obj._keymap = keymap
+ new_obj._tuplefilter = tuplefilter
+ new_obj._translated_indexes = translated_indexes
+ new_obj._safe_for_cache = safe_for_cache
+ new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx
+ return new_obj
+
+ def _remove_processors(self) -> CursorResultMetaData:
+ assert not self._tuplefilter
+ return self._make_new_metadata(
+ unpickled=self._unpickled,
+ processors=[None] * len(self._processors),
+ tuplefilter=None,
+ translated_indexes=None,
+ keymap={
+ key: value[0:5] + (None,) + value[6:]
+ for key, value in self._keymap.items()
+ },
+ keys=self._keys,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx=self._keymap_by_result_column_idx,
)
+ def _splice_horizontally(
+ self, other: CursorResultMetaData
+ ) -> CursorResultMetaData:
+
+ assert not self._tuplefilter
+
+ keymap = self._keymap.copy()
+ offset = len(self._keys)
+ keymap.update(
+ {
+ key: (
+ # int index should be None for ambiguous key
+ value[0] + offset
+ if value[0] is not None and key not in keymap
+ else None,
+ value[1] + offset,
+ *value[2:],
+ )
+ for key, value in other._keymap.items()
+ }
+ )
+
+ return self._make_new_metadata(
+ unpickled=self._unpickled,
+ processors=self._processors + other._processors, # type: ignore
+ tuplefilter=None,
+ translated_indexes=None,
+ keys=self._keys + other._keys, # type: ignore
+ keymap=keymap,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx={
+ metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry
+ for metadata_entry in keymap.values()
+ },
+ )
+
+ def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData:
+ recs = list(self._metadata_for_keys(keys))
+
indexes = [rec[MD_INDEX] for rec in recs]
new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs]
if self._translated_indexes:
indexes = [self._translated_indexes[idx] for idx in indexes]
tup = tuplegetter(*indexes)
-
- new_metadata = self.__class__.__new__(self.__class__)
- new_metadata._unpickled = self._unpickled
- new_metadata._processors = self._processors
- new_metadata._keys = new_keys
- new_metadata._tuplefilter = tup
- new_metadata._translated_indexes = indexes
-
new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)]
- new_metadata._keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
+ keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs}
# TODO: need unit test for:
# result = connection.execute("raw sql, no columns").scalars()
# without the "or ()" it's failing because MD_OBJECTS is None
- new_metadata._keymap.update(
+ keymap.update(
(e, new_rec)
for new_rec in new_recs
for e in new_rec[MD_OBJECTS] or ()
)
- return new_metadata
+ return self._make_new_metadata(
+ unpickled=self._unpickled,
+ processors=self._processors,
+ keys=new_keys,
+ tuplefilter=tup,
+ translated_indexes=indexes,
+ keymap=keymap,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx=self._keymap_by_result_column_idx,
+ )
def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData:
"""When using a cached Compiled construct that has a _result_map,
as matched to those of the cached statement.
"""
+
if not context.compiled or not context.compiled._result_columns:
return self
# make a copy and add the columns from the invoked statement
# to the result map.
- md = self.__class__.__new__(self.__class__)
keymap_by_position = self._keymap_by_result_column_idx
for metadata_entry in self._keymap.values()
}
- md._keymap = compat.dict_union(
- self._keymap,
- {
- new: keymap_by_position[idx]
- for idx, new in enumerate(
- invoked_statement._all_selected_columns
- )
- if idx in keymap_by_position
- },
- )
-
- md._unpickled = self._unpickled
- md._processors = self._processors
assert not self._tuplefilter
- md._tuplefilter = None
- md._translated_indexes = None
- md._keys = self._keys
- md._keymap_by_result_column_idx = self._keymap_by_result_column_idx
- md._safe_for_cache = self._safe_for_cache
- return md
+ return self._make_new_metadata(
+ keymap=compat.dict_union(
+ self._keymap,
+ {
+ new: keymap_by_position[idx]
+ for idx, new in enumerate(
+ invoked_statement._all_selected_columns
+ )
+ if idx in keymap_by_position
+ },
+ ),
+ unpickled=self._unpickled,
+ processors=self._processors,
+ tuplefilter=None,
+ translated_indexes=None,
+ keys=self._keys,
+ safe_for_cache=self._safe_for_cache,
+ keymap_by_result_column_idx=self._keymap_by_result_column_idx,
+ )
def __init__(
self,
untranslated,
)
- def _key_fallback(self, key, err, raiseerr=True):
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: Literal[True] = ...
+ ) -> NoReturn:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: Literal[False] = ...
+ ) -> None:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = ...
+ ) -> Optional[NoReturn]:
+ ...
+
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = True
+ ) -> Optional[NoReturn]:
if raiseerr:
if self._unpickled and isinstance(key, elements.ColumnElement):
try:
rec = self._keymap[key]
except KeyError as ke:
- rec = self._key_fallback(key, ke, raiseerr)
- if rec is None:
- return None
+ x = self._key_fallback(key, ke, raiseerr)
+ assert x is None
+ return None
index = rec[0]
def _metadata_for_keys(
self, keys: Sequence[Any]
- ) -> Iterator[_CursorKeyMapRecType]:
+ ) -> Iterator[_NonAmbigCursorKeyMapRecType]:
for key in keys:
if int in key.__class__.__mro__:
key = self._keys[key]
if index is None:
self._raise_for_ambiguous_column_name(rec)
- yield rec
+ yield cast(_NonAmbigCursorKeyMapRecType, rec)
def __getstate__(self):
return {
SelfCursorResult = TypeVar("SelfCursorResult", bound="CursorResult[Any]")
+def null_dml_result() -> IteratorResult[Any]:
+ it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([]))
+ it._soft_close()
+ return it
+
+
class CursorResult(Result[_T]):
"""A Result that is representing state from a DBAPI cursor.
"""
return self.context.returned_default_rows
+ def splice_horizontally(self, other):
+ """Return a new :class:`.CursorResult` that "horizontally splices"
+ together the rows of this :class:`.CursorResult` with that of another
+ :class:`.CursorResult`.
+
+ .. tip:: This method is for the benefit of the SQLAlchemy ORM and is
+ not intended for general use.
+
+ "horizontally splices" means that for each row in the first and second
+ result sets, a new row that concatenates the two rows together is
+ produced, which then becomes the new row. The incoming
+ :class:`.CursorResult` must have the identical number of rows. It is
+ typically expected that the two result sets come from the same sort
+ order as well, as the result rows are spliced together based on their
+ position in the result.
+
+ The expected use case here is so that multiple INSERT..RETURNING
+ statements against different tables can produce a single result
+ that looks like a JOIN of those two tables.
+
+ E.g.::
+
+ r1 = connection.execute(
+ users.insert().returning(users.c.user_name, users.c.user_id),
+ user_values
+ )
+
+ r2 = connection.execute(
+ addresses.insert().returning(
+ addresses.c.address_id,
+ addresses.c.address,
+ addresses.c.user_id,
+ ),
+ address_values
+ )
+
+ rows = r1.splice_horizontally(r2).all()
+ assert (
+ rows ==
+ [
+ ("john", 1, 1, "foo@bar.com", 1),
+ ("jack", 2, 2, "bar@bat.com", 2),
+ ]
+ )
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.CursorResult.splice_vertically`
+
+
+ """
+
+ clone = self._generate()
+ total_rows = [
+ tuple(r1) + tuple(r2)
+ for r1, r2 in zip(
+ list(self._raw_row_iterator()),
+ list(other._raw_row_iterator()),
+ )
+ ]
+
+ clone._metadata = clone._metadata._splice_horizontally(other._metadata)
+
+ clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
+ None,
+ initial_buffer=total_rows,
+ )
+ clone._reset_memoizations()
+ return clone
+
+ def splice_vertically(self, other):
+ """Return a new :class:`.CursorResult` that "vertically splices",
+ i.e. "extends", the rows of this :class:`.CursorResult` with that of
+ another :class:`.CursorResult`.
+
+ .. tip:: This method is for the benefit of the SQLAlchemy ORM and is
+ not intended for general use.
+
+ "vertically splices" means the rows of the given result are appended to
+ the rows of this cursor result. The incoming :class:`.CursorResult`
+ must have rows that represent the identical list of columns in the
+ identical order as they are in this :class:`.CursorResult`.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :ref:`.CursorResult.splice_horizontally`
+
+ """
+ clone = self._generate()
+ total_rows = list(self._raw_row_iterator()) + list(
+ other._raw_row_iterator()
+ )
+
+ clone.cursor_strategy = FullyBufferedCursorFetchStrategy(
+ None,
+ initial_buffer=total_rows,
+ )
+ clone._reset_memoizations()
+ return clone
+
+ def _rewind(self, rows):
+ """rewind this result back to the given rowset.
+
+ this is used internally for the case where an :class:`.Insert`
+ construct combines the use of
+ :meth:`.Insert.return_defaults` along with the
+ "supplemental columns" feature.
+
+ """
+
+ if self._echo:
+ self.context.connection._log_debug(
+ "CursorResult rewound %d row(s)", len(rows)
+ )
+
+ # the rows given are expected to be Row objects, so we
+ # have to clear out processors which have already run on these
+ # rows
+ self._metadata = cast(
+ CursorResultMetaData, self._metadata
+ )._remove_processors()
+
+ self.cursor_strategy = FullyBufferedCursorFetchStrategy(
+ None,
+ # TODO: if these are Row objects, can we save on not having to
+ # re-make new Row objects out of them a second time? is that
+ # what's actually happening right now? maybe look into this
+ initial_buffer=rows,
+ )
+ self._reset_memoizations()
+ return self
+
@property
def returned_defaults(self):
"""Return the values of default columns that were fetched using
_is_implicit_returning = False
_is_explicit_returning = False
+ _is_supplemental_returning = False
_is_server_side = False
_soft_closed = False
self.is_text = compiled.isplaintext
if ii or iu or id_:
+ dml_statement = compiled.compile_state.statement # type: ignore
if TYPE_CHECKING:
- assert isinstance(compiled.statement, UpdateBase)
+ assert isinstance(dml_statement, UpdateBase)
self.is_crud = True
- self._is_explicit_returning = ier = bool(
- compiled.statement._returning
- )
- self._is_implicit_returning = iir = is_implicit_returning = bool(
+ self._is_explicit_returning = ier = bool(dml_statement._returning)
+ self._is_implicit_returning = iir = bool(
compiled.implicit_returning
)
- assert not (
- is_implicit_returning and compiled.statement._returning
- )
+ if iir and dml_statement._supplemental_returning:
+ self._is_supplemental_returning = True
+
+ # dont mix implicit and explicit returning
+ assert not (iir and ier)
if (ier or iir) and compiled.for_executemany:
if ii and not self.dialect.insert_executemany_returning:
# are that the result has only one row, until executemany()
# support is added here.
assert result._metadata.returns_rows
- result._soft_close()
+
+ # Insert statement has both return_defaults() and
+ # returning(). rewind the result on the list of rows
+ # we just used.
+ if self._is_supplemental_returning:
+ result._rewind(rows)
+ else:
+ result._soft_close()
elif not self._is_explicit_returning:
result._soft_close()
# function so this is not necessarily true.
# assert not result.returns_rows
- elif self.isupdate and self._is_implicit_returning:
- # get rowcount
- # (which requires open cursor on some drivers)
- # we were not doing this in 1.4, however
- # test_rowcount -> test_update_rowcount_return_defaults
- # is testing this, and psycopg will no longer return
- # rowcount after cursor is closed.
- result.rowcount
- self._has_rowcount = True
+ elif self._is_implicit_returning:
+ rows = result.all()
- row = result.fetchone()
- if row is not None:
- self.returned_default_rows = [row]
+ if rows:
+ self.returned_default_rows = rows
+ result.rowcount = len(rows)
+ self._has_rowcount = True
- result._soft_close()
+ if self._is_supplemental_returning:
+ result._rewind(rows)
+ else:
+ result._soft_close()
# test that it has a cursor metadata that is accurate.
# the rows have all been fetched however.
elif self.isupdate or self.isdelete:
result.rowcount
self._has_rowcount = True
-
return result
@util.memoized_property
def _for_freeze(self) -> ResultMetaData:
raise NotImplementedError()
+ @overload
def _key_fallback(
- self, key: _KeyType, err: Exception, raiseerr: bool = True
+ self, key: Any, err: Exception, raiseerr: Literal[True] = ...
) -> NoReturn:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: Literal[False] = ...
+ ) -> None:
+ ...
+
+ @overload
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = ...
+ ) -> Optional[NoReturn]:
+ ...
+
+ def _key_fallback(
+ self, key: Any, err: Exception, raiseerr: bool = True
+ ) -> Optional[NoReturn]:
assert raiseerr
raise KeyError(key) from err
"""
_hard_closed = False
+ _soft_closed = False
def __init__(
self,
self.raw._soft_close(hard=hard, **kw)
self.iterator = iter([])
self._reset_memoizations()
+ self._soft_closed = True
def _raise_hard_closed(self) -> NoReturn:
raise exc.ResourceClosedError("This result object is closed.")
from __future__ import annotations
from typing import Any
+from typing import cast
from typing import Dict
from typing import Iterable
+from typing import Optional
+from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import attributes
+from . import context
from . import evaluator
from . import exc as orm_exc
+from . import loading
from . import persistence
from .base import NO_VALUE
from .context import AbstractORMCompileState
+from .context import FromStatement
+from .context import ORMFromStatementCompileState
+from .context import QueryContext
from .. import exc as sa_exc
-from .. import sql
from .. import util
from ..engine import Dialect
from ..engine import result as _result
from ..sql import coercions
+from ..sql import dml
from ..sql import expression
from ..sql import roles
from ..sql import select
if TYPE_CHECKING:
from .mapper import Mapper
+ from .session import _BindArguments
from .session import ORMExecuteState
+ from .session import Session
from .session import SessionTransaction
from .state import InstanceState
+ from ..engine import Connection
+ from ..engine import cursor
+ from ..engine.interfaces import _CoreAnyExecuteParams
+ from ..engine.interfaces import _ExecuteOptionsParameter
_O = TypeVar("_O", bound=object)
-_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+_SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"]
+_DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"]
+@overload
def _bulk_insert(
mapper: Mapper[_O],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
isstates: bool,
return_defaults: bool,
render_nulls: bool,
+ use_orm_insert_stmt: Literal[None] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = ...,
+ execution_options: Optional[_ExecuteOptionsParameter] = ...,
+) -> cursor.CursorResult[Any]:
+ ...
+
+
+def _bulk_insert(
+ mapper: Mapper[_O],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ return_defaults: bool,
+ render_nulls: bool,
+ use_orm_insert_stmt: Optional[dml.Insert] = None,
+ execution_options: Optional[_ExecuteOptionsParameter] = None,
+) -> Optional[cursor.CursorResult[Any]]:
base_mapper = mapper.base_mapper
if session_transaction.session.connection_callable:
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
connection = session_transaction.connection(base_mapper)
+
+ return_result: Optional[cursor.CursorResult[Any]] = None
+
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
+ is_joined_inh_supertable = super_mapper is not mapper
+ bookkeeping = (
+ is_joined_inh_supertable
+ or return_defaults
+ or (
+ use_orm_insert_stmt is not None
+ and bool(use_orm_insert_stmt._returning)
+ )
+ )
+
records = (
(
None,
table,
((None, mapping, mapper, connection) for mapping in mappings),
bulk=True,
- return_defaults=return_defaults,
+ return_defaults=bookkeeping,
render_nulls=render_nulls,
)
)
- persistence._emit_insert_statements(
+ result = persistence._emit_insert_statements(
base_mapper,
None,
super_mapper,
table,
records,
- bookkeeping=return_defaults,
+ bookkeeping=bookkeeping,
+ use_orm_insert_stmt=use_orm_insert_stmt,
+ execution_options=execution_options,
)
+ if use_orm_insert_stmt is not None:
+ if not use_orm_insert_stmt._returning or return_result is None:
+ return_result = result
+ elif result.returns_rows:
+ return_result = return_result.splice_horizontally(result)
if return_defaults and isstates:
identity_cls = mapper._identity_class
tuple([dict_[key] for key in identity_props]),
)
+ if use_orm_insert_stmt is not None:
+ assert return_result is not None
+ return return_result
+
+@overload
def _bulk_update(
mapper: Mapper[Any],
mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
session_transaction: SessionTransaction,
isstates: bool,
update_changed_only: bool,
+ use_orm_update_stmt: Literal[None] = ...,
) -> None:
+ ...
+
+
+@overload
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = ...,
+) -> _result.Result[Any]:
+ ...
+
+
+def _bulk_update(
+ mapper: Mapper[Any],
+ mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+ session_transaction: SessionTransaction,
+ isstates: bool,
+ update_changed_only: bool,
+ use_orm_update_stmt: Optional[dml.Update] = None,
+) -> Optional[_result.Result[Any]]:
base_mapper = mapper.base_mapper
search_keys = mapper._primary_key_propkeys
else:
mappings = [state.dict for state in mappings]
else:
- mappings = list(mappings)
+ mappings = [dict(m) for m in mappings]
+ _expand_composites(mapper, mappings)
if session_transaction.session.connection_callable:
raise NotImplementedError(
connection = session_transaction.connection(base_mapper)
for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
+ if not mapper.isa(super_mapper) or table not in mapper._pks_by_table:
continue
records = persistence._collect_update_commands(
for mapping in mappings
),
bulk=True,
+ use_orm_update_stmt=use_orm_update_stmt,
)
-
persistence._emit_update_statements(
base_mapper,
None,
table,
records,
bookkeeping=False,
+ use_orm_update_stmt=use_orm_update_stmt,
)
+ if use_orm_update_stmt is not None:
+ return _result.null_result()
+
+
+def _expand_composites(mapper, mappings):
+ composite_attrs = mapper.composites
+ if not composite_attrs:
+ return
+
+ composite_keys = set(composite_attrs.keys())
+ populators = {
+ key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn()
+ for key in composite_keys
+ }
+ for mapping in mappings:
+ for key in composite_keys.intersection(mapping):
+ populators[key](mapping)
+
class ORMDMLState(AbstractORMCompileState):
+ is_dml_returning = True
+ from_statement_ctx: Optional[ORMFromStatementCompileState] = None
+
+ @classmethod
+ def _get_orm_crud_kv_pairs(
+ cls, mapper, statement, kv_iterator, needs_to_be_cacheable
+ ):
+
+ core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+ for k, v in kv_iterator:
+ k = coercions.expect(roles.DMLColumnRole, k)
+
+ if isinstance(k, str):
+ desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+ if desc is NO_VALUE:
+ yield (
+ coercions.expect(roles.DMLColumnRole, k),
+ coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ )
+ if needs_to_be_cacheable
+ else v,
+ )
+ else:
+ yield from core_get_crud_kv_pairs(
+ statement,
+ desc._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ elif "entity_namespace" in k._annotations:
+ k_anno = k._annotations
+ attr = _entity_namespace_key(
+ k_anno["entity_namespace"], k_anno["proxy_key"]
+ )
+ yield from core_get_crud_kv_pairs(
+ statement,
+ attr._bulk_update_tuples(v),
+ needs_to_be_cacheable,
+ )
+ else:
+ yield (
+ k,
+ v
+ if not needs_to_be_cacheable
+ else coercions.expect(
+ roles.ExpressionElementRole,
+ v,
+ type_=sqltypes.NullType(),
+ is_crud=True,
+ ),
+ )
+
+ @classmethod
+ def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_multi_crud_kv_pairs(
+ statement, kv_iterator
+ )
+
+ return [
+ dict(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper, statement, value_dict.items(), False
+ )
+ )
+ for value_dict in kv_iterator
+ ]
+
+ @classmethod
+ def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable):
+ assert (
+ needs_to_be_cacheable
+ ), "no test coverage for needs_to_be_cacheable=False"
+
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+ if not plugin_subject or not plugin_subject.mapper:
+ return UpdateDMLState._get_crud_kv_pairs(
+ statement, kv_iterator, needs_to_be_cacheable
+ )
+
+ return list(
+ cls._get_orm_crud_kv_pairs(
+ plugin_subject.mapper,
+ statement,
+ kv_iterator,
+ needs_to_be_cacheable,
+ )
+ )
+
@classmethod
def get_entity_description(cls, statement):
ext_info = statement.table._annotations["parententity"]
]
]
+ def _setup_orm_returning(
+ self,
+ compiler,
+ orm_level_statement,
+ dml_level_statement,
+ use_supplemental_cols=True,
+ dml_mapper=None,
+ ):
+ """establish ORM column handlers for an INSERT, UPDATE, or DELETE
+ which uses explicit returning().
+
+ called within compilation level create_for_statement.
+
+ The _return_orm_returning() method then receives the Result
+ after the statement was executed, and applies ORM loading to the
+ state that we first established here.
+
+ """
+
+ if orm_level_statement._returning:
+
+ fs = FromStatement(
+ orm_level_statement._returning, dml_level_statement
+ )
+ fs = fs.options(*orm_level_statement._with_options)
+ self.select_statement = fs
+ self.from_statement_ctx = (
+ fsc
+ ) = ORMFromStatementCompileState.create_for_statement(fs, compiler)
+ fsc.setup_dml_returning_compile_state(dml_mapper)
+
+ dml_level_statement = dml_level_statement._generate()
+ dml_level_statement._returning = ()
+
+ cols_to_return = [c for c in fsc.primary_columns if c is not None]
+
+ # since we are splicing result sets together, make sure there
+ # are columns of some kind returned in each result set
+ if not cols_to_return:
+ cols_to_return.extend(dml_mapper.primary_key)
+
+ if use_supplemental_cols:
+ dml_level_statement = dml_level_statement.return_defaults(
+ supplemental_cols=cols_to_return
+ )
+ else:
+ dml_level_statement = dml_level_statement.returning(
+ *cols_to_return
+ )
+
+ return dml_level_statement
+
+ @classmethod
+ def _return_orm_returning(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ ):
+
+ execution_context = result.context
+ compile_state = execution_context.compiled.compile_state
+
+ if compile_state.from_statement_ctx:
+ load_options = execution_options.get(
+ "_sa_orm_load_options", QueryContext.default_load_options
+ )
+ querycontext = QueryContext(
+ compile_state.from_statement_ctx,
+ compile_state.select_statement,
+ params,
+ session,
+ load_options,
+ execution_options,
+ bind_arguments,
+ )
+ return loading.instances(result, querycontext)
+ else:
+ return result
+
class BulkUDCompileState(ORMDMLState):
class default_update_options(Options):
- _synchronize_session: _SynchronizeSessionArgument = "evaluate"
- _is_delete_using = False
- _is_update_from = False
- _autoflush = True
- _subject_mapper = None
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _synchronize_session: _SynchronizeSessionArgument = "auto"
+ _can_use_returning: bool = False
+ _is_delete_using: bool = False
+ _is_update_from: bool = False
+ _autoflush: bool = True
+ _subject_mapper: Optional[Mapper[Any]] = None
_resolved_values = EMPTY_DICT
- _resolved_keys_as_propnames = EMPTY_DICT
- _value_evaluators = EMPTY_DICT
- _matched_objects = None
+ _eval_condition = None
_matched_rows = None
_refresh_identity_token = None
execution_options,
) = BulkUDCompileState.default_update_options.from_execution_options(
"_sa_orm_update_options",
- {"synchronize_session", "is_delete_using", "is_update_from"},
+ {
+ "synchronize_session",
+ "is_delete_using",
+ "is_update_from",
+ "dml_strategy",
+ },
execution_options,
statement._execution_options,
)
- sync = update_options._synchronize_session
- if sync is not None:
- if sync not in ("evaluate", "fetch", False):
- raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False"
- )
-
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
update_options += {"_subject_mapper": plugin_subject.mapper}
+ if not isinstance(params, list):
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "orm"}
+ elif update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "bulk" ORM insert strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if update_options._dml_strategy == "auto":
+ update_options += {"_dml_strategy": "bulk"}
+ elif update_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "orm" ORM insert strategy with a '
+ "separate parameter list"
+ )
+
+ sync = update_options._synchronize_session
+ if sync is not None:
+ if sync not in ("auto", "evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'auto', 'evaluate', 'fetch', False"
+ )
+ if update_options._dml_strategy == "bulk" and sync == "fetch":
+ raise sa_exc.InvalidRequestError(
+ "The 'fetch' synchronization strategy is not available "
+ "for 'bulk' ORM updates (i.e. multiple parameter sets)"
+ )
+
if update_options._autoflush:
session._autoflush()
+ if update_options._dml_strategy == "orm":
+
+ if update_options._synchronize_session == "auto":
+ update_options = cls._do_pre_synchronize_auto(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "auto":
+ update_options += {"_synchronize_session": "evaluate"}
+
+ # indicators from the "pre exec" step that are then
+ # added to the DML statement, which will also be part of the cache
+ # key. The compile level create_for_statement() method will then
+ # consume these at compiler time.
statement = statement._annotate(
{
"synchronize_session": update_options._synchronize_session,
"is_delete_using": update_options._is_delete_using,
"is_update_from": update_options._is_update_from,
+ "dml_strategy": update_options._dml_strategy,
+ "can_use_returning": update_options._can_use_returning,
}
)
- # this stage of the execution is called before the do_orm_execute event
- # hook. meaning for an extension like horizontal sharding, this step
- # happens before the extension splits out into multiple backends and
- # runs only once. if we do pre_sync_fetch, we execute a SELECT
- # statement, which the horizontal sharding extension splits amongst the
- # shards and combines the results together.
-
- if update_options._synchronize_session == "evaluate":
- update_options = cls._do_pre_synchronize_evaluate(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "fetch":
- update_options = cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
-
return (
statement,
util.immutabledict(execution_options).union(
# individual ones we return here.
update_options = execution_options["_sa_orm_update_options"]
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(session, result, update_options)
- elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(session, result, update_options)
+ if update_options._dml_strategy == "orm":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(
+ session, statement, result, update_options
+ )
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(
+ session, statement, result, update_options
+ )
+ elif update_options._dml_strategy == "bulk":
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_bulk_evaluate(
+ session, params, result, update_options
+ )
+ return result
- return result
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
@classmethod
def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
primary_key_convert = [
lookup[bpk] for bpk in mapper.base_mapper.primary_key
]
-
return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
@classmethod
- def _do_pre_synchronize_evaluate(
+ def _get_matched_objects_on_criteria(cls, update_options, states):
+ mapper = update_options._subject_mapper
+ eval_condition = update_options._eval_condition
+
+ raw_data = [
+ (state.obj(), state, state.dict)
+ for state in states
+ if state.mapper.isa(mapper) and not state.expired
+ ]
+
+ identity_token = update_options._refresh_identity_token
+ if identity_token is not None:
+ raw_data = [
+ (obj, state, dict_)
+ for obj, state, dict_ in raw_data
+ if state.identity_token == identity_token
+ ]
+
+ result = []
+ for obj, state, dict_ in raw_data:
+ evaled_condition = eval_condition(obj)
+
+ # caution: don't use "in ()" or == here, _EXPIRE_OBJECT
+ # evaluates as True for all comparisons
+ if (
+ evaled_condition is True
+ or evaled_condition is evaluator._EXPIRED_OBJECT
+ ):
+ result.append(
+ (
+ obj,
+ state,
+ dict_,
+ evaled_condition is evaluator._EXPIRED_OBJECT,
+ )
+ )
+ return result
+
+ @classmethod
+ def _eval_condition_from_statement(cls, update_options, statement):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ crit = ()
+ if statement._where_criteria:
+ crit += statement._where_criteria
+
+ global_attributes = {}
+ for opt in statement._with_options:
+ if opt._is_criteria_option:
+ opt.get_global_criteria(global_attributes)
+
+ if global_attributes:
+ crit += cls._adjust_for_extra_criteria(global_attributes, mapper)
+
+ if crit:
+ eval_condition = evaluator_compiler.process(*crit)
+ else:
+
+ def eval_condition(obj):
+ return True
+
+ return eval_condition
+
+ @classmethod
+ def _do_pre_synchronize_auto(
cls,
session,
statement,
bind_arguments,
update_options,
):
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
+ """setup auto sync strategy
+
+
+ "auto" checks if we can use "evaluate" first, then falls back
+ to "fetch"
- value_evaluators = resolved_keys_as_propnames = EMPTY_DICT
+ evaluate is vastly more efficient for the common case
+ where session is empty, only has a few objects, and the UPDATE
+ statement can potentially match thousands/millions of rows.
+
+ OTOH more complex criteria that fails to work with "evaluate"
+ we would hope usually correlates with fewer net rows.
+
+ """
try:
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- crit = ()
- if statement._where_criteria:
- crit += statement._where_criteria
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
- global_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(global_attributes)
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ return update_options + {
+ "_eval_condition": eval_condition,
+ "_synchronize_session": "evaluate",
+ }
- if global_attributes:
- crit += cls._adjust_for_extra_criteria(
- global_attributes, mapper
- )
+ update_options += {"_synchronize_session": "fetch"}
+ return cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
- if crit:
- eval_condition = evaluator_compiler.process(*crit)
- else:
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
- def eval_condition(obj):
- return True
+ try:
+ eval_condition = cls._eval_condition_from_statement(
+ update_options, statement
+ )
except evaluator.UnevaluatableError as err:
raise sa_exc.InvalidRequestError(
"synchronize_session execution option." % err
) from err
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- value_evaluators = {}
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- # TODO: detect when the where clause is a trivial primary key match.
- matched_objects = [
- state.obj()
- for state in session.identity_map.all_states()
- if state.mapper.isa(mapper)
- and not state.expired
- and eval_condition(state.obj())
- and (
- update_options._refresh_identity_token is None
- # TODO: coverage for the case where horizontal sharding
- # invokes an update() or delete() given an explicit identity
- # token up front
- or state.identity_token
- == update_options._refresh_identity_token
- )
- ]
return update_options + {
- "_matched_objects": matched_objects,
- "_value_evaluators": value_evaluators,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_eval_condition": eval_condition,
}
@classmethod
def _resolved_keys_as_propnames(cls, mapper, resolved_values):
values = []
for k, v in resolved_values:
- if isinstance(k, attributes.QueryableAttribute):
- values.append((k.key, v))
- continue
- elif hasattr(k, "__clause_element__"):
- k = k.__clause_element__()
-
if mapper and isinstance(k, expression.ColumnElement):
try:
attr = mapper._columntoproperty[k]
values.append((attr.key, v))
else:
raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k
+ "Attribute name not found, can't be "
+ "synchronized back to objects: %r" % k
)
return values
)
select_stmt._where_criteria = statement._where_criteria
+ # conditionally run the SELECT statement for pre-fetch, testing the
+ # "bind" for if we can use RETURNING or not using the do_orm_execute
+ # event. If RETURNING is available, the do_orm_execute event
+ # will cancel the SELECT from being actually run.
+ #
+ # The way this is organized seems strange, why don't we just
+ # call can_use_returning() before invoking the statement and get
+ # answer?, why does this go through the whole execute phase using an
+ # event? Answer: because we are integrating with extensions such
+ # as the horizontal sharding extention that "multiplexes" an individual
+ # statement run through multiple engines, and it uses
+ # do_orm_execute() to do that.
+
+ can_use_returning = None
+
def skip_for_returning(orm_context: ORMExecuteState) -> Any:
bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- if cls.can_use_returning(
+ nonlocal can_use_returning
+
+ per_bind_result = cls.can_use_returning(
bind.dialect,
mapper,
is_update_from=update_options._is_update_from,
is_delete_using=update_options._is_delete_using,
- ):
- return _result.null_result()
- else:
- return None
+ )
+
+ if can_use_returning is not None:
+ if can_use_returning != per_bind_result:
+ raise sa_exc.InvalidRequestError(
+ "For synchronize_session='fetch', can't mix multiple "
+ "backends where some support RETURNING and others "
+ "don't"
+ )
+ else:
+ can_use_returning = per_bind_result
+
+ if per_bind_result:
+ return _result.null_result()
+ else:
+ return None
result = session.execute(
select_stmt,
)
matched_rows = result.fetchall()
- value_evaluators = EMPTY_DICT
-
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- target_cls = mapper.class_
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
-
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- value_evaluators = {}
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- else:
- resolved_keys_as_propnames = EMPTY_DICT
-
return update_options + {
- "_value_evaluators": value_evaluators,
"_matched_rows": matched_rows,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ "_can_use_returning": can_use_returning,
}
@CompileState.plugin_for("orm", "insert")
-class ORMInsert(ORMDMLState, InsertDMLState):
+class BulkORMInsert(ORMDMLState, InsertDMLState):
+ class default_insert_options(Options):
+ _dml_strategy: _DMLStrategyArgument = "auto"
+ _render_nulls: bool = False
+ _return_defaults: bool = False
+ _subject_mapper: Optional[Mapper[Any]] = None
+
+ select_statement: Optional[FromStatement] = None
+
@classmethod
def orm_pre_session_exec(
cls,
bind_arguments,
is_reentrant_invoke,
):
+
+ (
+ insert_options,
+ execution_options,
+ ) = BulkORMInsert.default_insert_options.from_execution_options(
+ "_sa_orm_insert_options",
+ {"dml_strategy"},
+ execution_options,
+ statement._execution_options,
+ )
bind_arguments["clause"] = statement
try:
plugin_subject = statement._propagate_attrs["plugin_subject"]
else:
bind_arguments["mapper"] = plugin_subject.mapper
+ insert_options += {"_subject_mapper": plugin_subject.mapper}
+
+ if not params:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "orm"}
+ elif insert_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "bulk" ORM insert strategy without '
+ "passing separate parameters"
+ )
+ else:
+ if insert_options._dml_strategy == "auto":
+ insert_options += {"_dml_strategy": "bulk"}
+ elif insert_options._dml_strategy == "orm":
+ raise sa_exc.InvalidRequestError(
+ 'Can\'t use "orm" ORM insert strategy with a '
+ "separate parameter list"
+ )
+
+ if insert_options._dml_strategy != "raw":
+ # for ORM object loading, like ORMContext, we have to disable
+ # result set adapt_to_context, because we will be generating a
+ # new statement with specific columns that's cached inside of
+ # an ORMFromStatementCompileState, which we will re-use for
+ # each result.
+ if not execution_options:
+ execution_options = context._orm_load_exec_options
+ else:
+ execution_options = execution_options.union(
+ context._orm_load_exec_options
+ )
+
+ statement = statement._annotate(
+ {"dml_strategy": insert_options._dml_strategy}
+ )
+
return (
statement,
- util.immutabledict(execution_options),
+ util.immutabledict(execution_options).union(
+ {"_sa_orm_insert_options": insert_options}
+ ),
)
@classmethod
- def orm_setup_cursor_result(
+ def orm_execute_statement(
cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
- return result
+ session: Session,
+ statement: dml.Insert,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ insert_options = execution_options.get(
+ "_sa_orm_insert_options", cls.default_insert_options
+ )
+
+ if insert_options._dml_strategy not in (
+ "raw",
+ "bulk",
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM insert strategy "
+ "are 'raw', 'orm', 'bulk', 'auto"
+ )
+
+ result: _result.Result[Any]
+
+ if insert_options._dml_strategy == "raw":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return result
+
+ if insert_options._dml_strategy == "bulk":
+ mapper = insert_options._subject_mapper
+
+ if (
+ statement._post_values_clause is not None
+ and mapper._multiple_persistence_tables
+ ):
+ raise sa_exc.InvalidRequestError(
+ "bulk INSERT with a 'post values' clause "
+ "(typically upsert) not supported for multi-table "
+ f"mapper {mapper}"
+ )
+
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_insert(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ return_defaults=insert_options._return_defaults,
+ render_nulls=insert_options._render_nulls,
+ use_orm_insert_stmt=statement,
+ execution_options=execution_options,
+ )
+ elif insert_options._dml_strategy == "orm":
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ else:
+ raise AssertionError()
+
+ if not bool(statement._returning):
+ return result
+
+ return cls._return_orm_returning(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
+
+ self = cast(
+ BulkORMInsert,
+ super().create_for_statement(statement, compiler, **kw),
+ )
+
+ if compiler is not None:
+ toplevel = not compiler.stack
+ else:
+ toplevel = True
+ if not toplevel:
+ return self
+
+ mapper = statement._propagate_attrs["plugin_subject"]
+ dml_strategy = statement._annotations.get("dml_strategy", "raw")
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_insert(compiler)
+ elif dml_strategy == "orm":
+ self._setup_for_orm_insert(compiler, mapper)
+
+ return self
+
+ @classmethod
+ def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict):
+ return {
+ col.key if col is not None else k: v
+ for col, k, v in (
+ (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items()
+ )
+ }
+
+ def _setup_for_orm_insert(self, compiler, mapper):
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=False,
+ )
+ self.statement = statement
+
+ def _setup_for_bulk_insert(self, compiler):
+ """establish an INSERT statement within the context of
+ bulk insert.
+
+ This method will be within the "conn.execute()" call that is invoked
+ by persistence._emit_insert_statement().
+
+ """
+ statement = orm_level_statement = cast(dml.Insert, self.statement)
+ an = statement._annotations
+
+ emit_insert_table, emit_insert_mapper = (
+ an["_emit_insert_table"],
+ an["_emit_insert_mapper"],
+ )
+
+ statement = statement._clone()
+
+ statement.table = emit_insert_table
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_insert_table
+ }
+
+ statement = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ statement,
+ use_supplemental_cols=True,
+ dml_mapper=emit_insert_mapper,
+ )
+
+ self.statement = statement
@CompileState.plugin_for("orm", "update")
self = cls.__new__(cls)
+ dml_strategy = statement._annotations.get(
+ "dml_strategy", "unspecified"
+ )
+
+ if dml_strategy == "bulk":
+ self._setup_for_bulk_update(statement, compiler)
+ elif dml_strategy in ("orm", "unspecified"):
+ self._setup_for_orm_update(statement, compiler)
+
+ return self
+
+ def _setup_for_orm_update(self, statement, compiler, **kw):
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
self.extra_criteria_entities = {}
- self._resolved_values = cls._get_resolved_values(mapper, statement)
+ self._resolved_values = self._get_resolved_values(mapper, statement)
extra_criteria_attributes = {}
if statement._values:
self._resolved_values = dict(self._resolved_values)
- new_stmt = sql.Update.__new__(sql.Update)
- new_stmt.__dict__.update(statement.__dict__)
+ new_stmt = statement._clone()
new_stmt.table = mapper.local_table
# note if the statement has _multi_values, these
elif statement._values:
new_stmt._values = self._resolved_values
- new_crit = cls._adjust_for_extra_criteria(
+ new_crit = self._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
- if compiler._annotations.get(
+ use_supplemental_cols = False
+
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect, mapper, is_multitable=self.is_multitable
- ):
- if new_stmt._returning:
- raise sa_exc.InvalidRequestError(
- "Can't use synchronize_session='fetch' "
- "with explicit returning()"
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect, mapper, is_multitable=self.is_multitable
)
- self.statement = self.statement.returning(
- *mapper.local_table.primary_key
)
- return self
+ if synchronize_session == "fetch" and can_use_returning:
+ use_supplemental_cols = True
+
+ # NOTE: we might want to RETURNING the actual columns to be
+ # synchronized also. however this is complicated and difficult
+ # to align against the behavior of "evaluate". Additionally,
+ # in a large number (if not the majority) of cases, we have the
+ # "evaluate" answer, usually a fixed value, in memory already and
+ # there's no need to re-fetch the same value
+ # over and over again. so perhaps if it could be RETURNING just
+ # the elements that were based on a SQL expression and not
+ # a constant. For now it doesn't quite seem worth it
+ new_stmt = new_stmt.return_defaults(
+ *(list(mapper.local_table.primary_key))
+ )
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
+
+ def _setup_for_bulk_update(self, statement, compiler, **kw):
+ """establish an UPDATE statement within the context of
+ bulk insert.
+
+ This method will be within the "conn.execute()" call that is invoked
+ by persistence._emit_update_statement().
+
+ """
+ statement = cast(dml.Update, statement)
+ an = statement._annotations
+
+ emit_update_table, _ = (
+ an["_emit_update_table"],
+ an["_emit_update_mapper"],
+ )
+
+ statement = statement._clone()
+ statement.table = emit_update_table
+
+ UpdateDMLState.__init__(self, statement, compiler, **kw)
+
+ if self._ordered_values:
+ raise sa_exc.InvalidRequestError(
+ "bulk ORM UPDATE does not support ordered_values() for "
+ "custom UPDATE statements with bulk parameter sets. Use a "
+ "non-bulk UPDATE statement or use values()."
+ )
+
+ if self._dict_parameters:
+ self._dict_parameters = {
+ col: val
+ for col, val in self._dict_parameters.items()
+ if col.table is emit_update_table
+ }
+ self.statement = statement
+
+ @classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Update,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy not in ("orm", "auto", "bulk"):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM UPDATE strategy "
+ "are 'orm', 'auto', 'bulk'"
+ )
+
+ result: _result.Result[Any]
+
+ if update_options._dml_strategy == "bulk":
+ if statement._where_criteria:
+ raise sa_exc.InvalidRequestError(
+ "WHERE clause with bulk ORM UPDATE not "
+ "supported right now. Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+ mapper = update_options._subject_mapper
+ assert mapper is not None
+ assert session._transaction is not None
+ result = _bulk_update(
+ mapper,
+ cast(
+ "Iterable[Dict[str, Any]]",
+ [params] if isinstance(params, dict) else params,
+ ),
+ session._transaction,
+ isstates=False,
+ update_changed_only=False,
+ use_orm_update_stmt=statement,
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+ else:
+ return super().orm_execute_statement(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ )
@classmethod
def can_use_returning(
return True
@classmethod
- def _get_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = statement._propagate_attrs["plugin_subject"]
-
- core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
-
- if not plugin_subject or not plugin_subject.mapper:
- return core_get_crud_kv_pairs(statement, kv_iterator)
-
- mapper = plugin_subject.mapper
-
- values = []
-
- for k, v in kv_iterator:
- k = coercions.expect(roles.DMLColumnRole, k)
+ def _do_post_synchronize_bulk_evaluate(
+ cls, session, params, result, update_options
+ ):
+ if not params:
+ return
- if isinstance(k, str):
- desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
- if desc is NO_VALUE:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- else:
- values.extend(
- core_get_crud_kv_pairs(
- statement, desc._bulk_update_tuples(v)
- )
- )
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- values.extend(
- core_get_crud_kv_pairs(
- statement, attr._bulk_update_tuples(v)
- )
- )
- else:
- values.append(
- (
- k,
- coercions.expect(
- roles.ExpressionElementRole,
- v,
- type_=sqltypes.NullType(),
- is_crud=True,
- ),
- )
- )
- return values
+ mapper = update_options._subject_mapper
+ pk_keys = [prop.key for prop in mapper._identity_key_props]
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
+ identity_map = session.identity_map
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in update_options._matched_objects:
-
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
+ for param in params:
+ identity_key = mapper.identity_key_from_primary_key(
+ (param[key] for key in pk_keys),
+ update_options._refresh_identity_token,
)
-
- # the evaluated states were gathered across all identity tokens.
- # however the post_sync events are called per identity token,
- # so filter.
- if (
- update_options._refresh_identity_token is not None
- and state.identity_token
- != update_options._refresh_identity_token
- ):
+ state = identity_map.fast_get_state(identity_key)
+ if not state:
continue
+ evaluated_keys = set(param).difference(pk_keys)
+
+ dict_ = state.dict
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ dict_[key] = param[key]
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush were set to False.
+ to_expire = evaluated_keys.intersection(dict_).difference(
+ to_evaluate
+ )
if to_expire:
state._expire_attributes(dict_, to_expire)
- states.add(state)
- session._register_altered(states)
+ @classmethod
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
+ )
+
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [(obj, state, dict_) for obj, state, dict_, _ in matched_objects],
+ )
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
-
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
if identity_key in session.identity_map
]
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
+ if not objs:
+ return
- for obj in objs:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
+ cls._apply_update_set_values_to_objects(
+ session,
+ update_options,
+ statement,
+ [
+ (
+ obj,
+ attributes.instance_state(obj),
+ attributes.instance_dict(obj),
+ )
+ for obj in objs
+ ],
+ )
+
+ @classmethod
+ def _apply_update_set_values_to_objects(
+ cls, session, update_options, statement, matched_objects
+ ):
+ """apply values to objects derived from an update statement, e.g.
+ UPDATE..SET <values>
+
+ """
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
+ evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ value_evaluators = {}
+ for key, value in resolved_keys_as_propnames:
+ try:
+ _evaluator = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
+ except evaluator.UnevaluatableError:
+ pass
+ else:
+ value_evaluators[key] = _evaluator
+
+ evaluated_keys = list(value_evaluators.keys())
+ attrib = set(k for k, v in resolved_keys_as_propnames)
+
+ states = set()
+ for obj, state, dict_ in matched_objects:
to_evaluate = state.unmodified.intersection(evaluated_keys)
+
for key in to_evaluate:
if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
+ # only run eval for attributes that are present.
+ dict_[key] = value_evaluators[key](obj)
+
state.manager.dispatch.refresh(state, None, to_evaluate)
state._commit(dict_, list(to_evaluate))
+ # attributes that were formerly modified instead get expired.
+ # this only gets hit if the session had pending changes
+ # and autoflush were set to False.
to_expire = attrib.intersection(dict_).difference(to_evaluate)
if to_expire:
state._expire_attributes(dict_, to_expire)
def create_for_statement(cls, statement, compiler, **kw):
self = cls.__new__(cls)
+ orm_level_statement = statement
+
ext_info = statement.table._annotations["parententity"]
self.mapper = mapper = ext_info.mapper
if opt._is_criteria_option:
opt.get_global_criteria(extra_criteria_attributes)
+ new_stmt = statement._clone()
+ new_stmt.table = mapper.local_table
+
new_crit = cls._adjust_for_extra_criteria(
extra_criteria_attributes, mapper
)
if new_crit:
- statement = statement.where(*new_crit)
+ new_stmt = new_stmt.where(*new_crit)
# do this first as we need to determine if there is
# DELETE..FROM
- DeleteDMLState.__init__(self, statement, compiler, **kw)
+ DeleteDMLState.__init__(self, new_stmt, compiler, **kw)
- if compiler._annotations.get(
+ use_supplemental_cols = False
+
+ synchronize_session = compiler._annotations.get(
"synchronize_session", None
- ) == "fetch" and self.can_use_returning(
- compiler.dialect,
- mapper,
- is_multitable=self.is_multitable,
- is_delete_using=compiler._annotations.get(
- "is_delete_using", False
- ),
- ):
- self.statement = statement.returning(*statement.table.primary_key)
+ )
+ can_use_returning = compiler._annotations.get(
+ "can_use_returning", None
+ )
+ if can_use_returning is not False:
+ # even though pre_exec has determined basic
+ # can_use_returning for the dialect, if we are to use
+ # RETURNING we need to run can_use_returning() at this level
+ # unconditionally because is_delete_using was not known
+ # at the pre_exec level
+ can_use_returning = (
+ synchronize_session == "fetch"
+ and self.can_use_returning(
+ compiler.dialect,
+ mapper,
+ is_multitable=self.is_multitable,
+ is_delete_using=compiler._annotations.get(
+ "is_delete_using", False
+ ),
+ )
+ )
+
+ if can_use_returning:
+ use_supplemental_cols = True
+
+ new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key)
+
+ new_stmt = self._setup_orm_returning(
+ compiler,
+ orm_level_statement,
+ new_stmt,
+ use_supplemental_cols=use_supplemental_cols,
+ )
+
+ self.statement = new_stmt
return self
+ @classmethod
+ def orm_execute_statement(
+ cls,
+ session: Session,
+ statement: dml.Delete,
+ params: _CoreAnyExecuteParams,
+ execution_options: _ExecuteOptionsParameter,
+ bind_arguments: _BindArguments,
+ conn: Connection,
+ ) -> _result.Result:
+
+ update_options = execution_options.get(
+ "_sa_orm_update_options", cls.default_update_options
+ )
+
+ if update_options._dml_strategy == "bulk":
+ raise sa_exc.InvalidRequestError(
+ "Bulk ORM DELETE not supported right now. "
+ "Statement may be invoked at the "
+ "Core level using "
+ "session.connection().execute(stmt, parameters)"
+ )
+
+ if update_options._dml_strategy not in (
+ "orm",
+ "auto",
+ ):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for ORM DELETE strategy are 'orm', 'auto'"
+ )
+
+ return super().orm_execute_statement(
+ session, statement, params, execution_options, bind_arguments, conn
+ )
+
@classmethod
def can_use_returning(
cls,
return True
@classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
- session._remove_newly_deleted(
- [
- attributes.instance_state(obj)
- for obj in update_options._matched_objects
- ]
+ def _do_post_synchronize_evaluate(
+ cls, session, statement, result, update_options
+ ):
+ matched_objects = cls._get_matched_objects_on_criteria(
+ update_options,
+ session.identity_map.all_states(),
)
+ to_delete = []
+
+ for _, state, dict_, is_partially_expired in matched_objects:
+ if is_partially_expired:
+ state._expire(dict_, session.identity_map._modified)
+ else:
+ to_delete.append(state)
+
+ if to_delete:
+ session._remove_newly_deleted(to_delete)
+
@classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
+ def _do_post_synchronize_fetch(
+ cls, session, statement, result, update_options
+ ):
target_mapper = update_options._subject_mapper
- if result.returns_rows:
- rows = cls._interpret_returning_rows(target_mapper, result.all())
+ returned_defaults_rows = result.returned_defaults_rows
+
+ if returned_defaults_rows:
+ pk_rows = cls._interpret_returning_rows(
+ target_mapper, returned_defaults_rows
+ )
matched_rows = [
tuple(row) + (update_options._refresh_identity_token,)
- for row in rows
+ for row in pk_rows
]
else:
matched_rows = update_options._matched_rows
from .query import Query
from .session import _BindArguments
from .session import Session
+ from ..engine import Result
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptionsParameter
from ..sql._typing import _ColumnsClauseArgument
class AbstractORMCompileState(CompileState):
+ is_dml_returning = False
+
@classmethod
def create_for_statement(
cls,
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> AbstractORMCompileState:
"""Create a context for a statement given a :class:`.Compiler`.
+
This method is always invoked in the context of SQLCompiler.process().
+
For a Select object, this would be invoked from
SQLCompiler.visit_select(). For the special FromStatement object used
by Query to indicate "Query.from_statement()", this is called by
):
raise NotImplementedError()
+ @classmethod
+ def orm_execute_statement(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ conn,
+ ) -> Result:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
+ )
+ return cls.orm_setup_cursor_result(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ result,
+ )
+
@classmethod
def orm_setup_cursor_result(
cls,
def __init__(self, *arg, **kw):
raise NotImplementedError()
+ if TYPE_CHECKING:
+
+ @classmethod
+ def create_for_statement(
+ cls,
+ statement: Union[Select, FromStatement],
+ compiler: Optional[SQLCompiler],
+ **kw: Any,
+ ) -> ORMCompileState:
+ ...
+
def _append_dedupe_col_collection(self, obj, col_collection):
dedupe = self.dedupe_columns
if obj not in dedupe:
else:
return SelectState._column_naming_convention(label_style)
- @classmethod
- def create_for_statement(
- cls,
- statement: Union[Select, FromStatement],
- compiler: Optional[SQLCompiler],
- **kw: Any,
- ) -> ORMCompileState:
- """Create a context for a statement given a :class:`.Compiler`.
-
- This method is always invoked in the context of SQLCompiler.process().
-
- For a Select object, this would be invoked from
- SQLCompiler.visit_select(). For the special FromStatement object used
- by Query to indicate "Query.from_statement()", this is called by
- FromStatement._compiler_dispatch() that would be called by
- SQLCompiler.process().
-
- """
- raise NotImplementedError()
-
@classmethod
def get_column_descriptions(cls, statement):
return _column_descriptions(statement)
)
+class DMLReturningColFilter:
+ """an adapter used for the DML RETURNING case.
+
+ Has a subset of the interface used by
+ :class:`.ORMAdapter` and is used for :class:`._QueryEntity`
+ instances to set up their columns as used in RETURNING for a
+ DML statement.
+
+ """
+
+ __slots__ = ("mapper", "columns", "__weakref__")
+
+ def __init__(self, target_mapper, immediate_dml_mapper):
+ if (
+ immediate_dml_mapper is not None
+ and target_mapper.local_table
+ is not immediate_dml_mapper.local_table
+ ):
+ # joined inh, or in theory other kinds of multi-table mappings
+ self.mapper = immediate_dml_mapper
+ else:
+ # single inh, normal mappings, etc.
+ self.mapper = target_mapper
+ self.columns = self.columns = util.WeakPopulateDict(
+ self.adapt_check_present # type: ignore
+ )
+
+ def __call__(self, col, as_filter):
+ for cc in sql_util._find_columns(col):
+ c2 = self.adapt_check_present(cc)
+ if c2 is not None:
+ return col
+ else:
+ return None
+
+ def adapt_check_present(self, col):
+ mapper = self.mapper
+ prop = mapper._columntoproperty.get(col, None)
+ if prop is None:
+ return None
+ return mapper.local_table.c.corresponding_column(col)
+
+
@sql.base.CompileState.plugin_for("orm", "orm_from_statement")
class ORMFromStatementCompileState(ORMCompileState):
_from_obj_alias = None
statement_container: FromStatement
requested_statement: Union[SelectBase, TextClause, UpdateBase]
- dml_table: _DMLTableElement
+ dml_table: Optional[_DMLTableElement] = None
_has_orm_entities = False
multi_row_eager_loaders = False
statement_container: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMFromStatementCompileState:
if compiler is not None:
toplevel = not compiler.stack
if statement.is_dml:
self.dml_table = statement.table
+ self.is_dml_returning = True
self._entities = []
self._polymorphic_adapters = {}
def _get_current_adapter(self):
return None
+ def setup_dml_returning_compile_state(self, dml_mapper):
+ """used by BulkORMInsert (and Update / Delete?) to set up a handler
+ for RETURNING to return ORM objects and expressions
+
+ """
+ target_mapper = self.statement._propagate_attrs.get(
+ "plugin_subject", None
+ )
+ adapter = DMLReturningColFilter(target_mapper, dml_mapper)
+ for entity in self._entities:
+ entity.setup_dml_returning_compile_state(self, adapter)
+
class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]):
"""Core construct that represents a load of ORM objects from various
statement: Union[Select, FromStatement],
compiler: Optional[SQLCompiler],
**kw: Any,
- ) -> ORMCompileState:
+ ) -> ORMSelectCompileState:
"""compiler hook, we arrive here from compiler.visit_select() only."""
self = cls.__new__(cls)
def setup_compile_state(self, compile_state: ORMCompileState) -> None:
raise NotImplementedError()
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ raise NotImplementedError()
+
def row_processor(self, context, result):
raise NotImplementedError()
return _instance, self._label_name, self._extra_entities
- def setup_compile_state(self, compile_state):
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ loading._setup_entity_query(
+ compile_state,
+ self.mapper,
+ self,
+ self.path,
+ adapter,
+ compile_state.primary_columns,
+ with_polymorphic=self._with_polymorphic_mappers,
+ only_load_props=compile_state.compile_options._only_load_props,
+ polymorphic_discriminator=self._polymorphic_discriminator,
+ )
+ def setup_compile_state(self, compile_state):
adapter = self._get_entity_clauses(compile_state)
single_table_crit = self.mapper._single_table_criterion
only_load_props=compile_state.compile_options._only_load_props,
polymorphic_discriminator=self._polymorphic_discriminator,
)
-
compile_state._fallback_from_clauses.append(self.selectable)
getter, label_name, extra_entities = self._row_processor
if self.translate_raw_column:
extra_entities += (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, label_name, extra_entities
if self.translate_raw_column:
extra_entities = self._extra_entities + (
- result.context.invoked_statement._raw_columns[
- self.raw_column_index
- ],
+ context.query._raw_columns[self.raw_column_index],
)
return getter, self._label_name, extra_entities
else:
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ return
else:
column = self.column
self.entity_zero
) and entity.common_parent(self.entity_zero)
+ def setup_dml_returning_compile_state(
+ self,
+ compile_state: ORMCompileState,
+ adapter: DMLReturningColFilter,
+ ) -> None:
+ self._fetch_column = self.column
+ column = adapter(self.column, False)
+ if column is not None:
+ compile_state.dedupe_columns.add(column)
+ compile_state.primary_columns.append(column)
+
def setup_compile_state(self, compile_state):
current_adapter = compile_state._get_current_adapter()
if current_adapter:
column = current_adapter(self.column, False)
+ if column is None:
+ assert compile_state.is_dml_returning
+ self._fetch_column = self.column
+ return
else:
column = self.column
import typing
from typing import Any
from typing import Callable
+from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
def _attribute_keys(self) -> Sequence[str]:
return [prop.key for prop in self.props]
+ def _populate_composite_bulk_save_mappings_fn(
+ self,
+ ) -> Callable[[Dict[str, Any]], None]:
+
+ if self._generated_composite_accessor:
+ get_values = self._generated_composite_accessor
+ else:
+
+ def get_values(val: Any) -> Tuple[Any]:
+ return val.__composite_values__() # type: ignore
+
+ attrs = [prop.key for prop in self.props]
+
+ def populate(dest_dict: Dict[str, Any]) -> None:
+ dest_dict.update(
+ {
+ key: val
+ for key, val in zip(
+ attrs, get_values(dest_dict.pop(self.key))
+ )
+ }
+ )
+
+ return populate
+
def get_history(
self,
state: InstanceState[Any],
from __future__ import annotations
-import operator
-
+from .base import LoaderCallableStatus
+from .base import PassiveFlag
from .. import exc
from .. import inspect
from .. import util
return None
+class _ExpiredObject(operators.ColumnOperators):
+ def operate(self, *arg, **kw):
+ return self
+
+ def reverse_operate(self, *arg, **kw):
+ return self
+
+
_NO_OBJECT = _NoObject()
+_EXPIRED_OBJECT = _ExpiredObject()
class EvaluatorCompiler:
f"alternate class {parentmapper.class_}"
)
key = parentmapper._columntoproperty[clause].key
+ impl = parentmapper.class_manager[key].impl
+
+ if impl is not None:
+
+ def get_corresponding_attr(obj):
+ if obj is None:
+ return _NO_OBJECT
+ state = inspect(obj)
+ dict_ = state.dict
+
+ value = impl.get(
+ state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH
+ )
+ if value is LoaderCallableStatus.PASSIVE_NO_RESULT:
+ return _EXPIRED_OBJECT
+ return value
+
+ return get_corresponding_attr
else:
key = clause.key
if (
"make use of the actual mapped columns in ORM-evaluated "
"UPDATE / DELETE expressions."
)
+
else:
raise UnevaluatableError(f"Cannot evaluate column: {clause}")
- get_corresponding_attr = operator.attrgetter(key)
- return (
- lambda obj: get_corresponding_attr(obj)
- if obj is not None
- else _NO_OBJECT
- )
+ def get_corresponding_attr(obj):
+ if obj is None:
+ return _NO_OBJECT
+ return getattr(obj, key, _EXPIRED_OBJECT)
+
+ return get_corresponding_attr
def visit_tuple(self, clause):
return self.visit_clauselist(clause)
has_null = False
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value:
return True
has_null = has_null or value is None
if has_null:
def evaluate(obj):
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+
if not value:
if value is None or value is _NO_OBJECT:
return None
values = []
for sub_evaluate in evaluators:
value = sub_evaluate(obj)
- if value is None or value is _NO_OBJECT:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value is None or value is _NO_OBJECT:
return None
values.append(value)
return tuple(values)
def visit_is_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
- return eval_left(obj) == eval_right(obj)
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ return left_val == right_val
return evaluate
def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause):
def evaluate(obj):
- return eval_left(obj) != eval_right(obj)
+ left_val = eval_left(obj)
+ right_val = eval_right(obj)
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ return left_val != right_val
return evaluate
def evaluate(obj):
left_val = eval_left(obj)
right_val = eval_right(obj)
- if left_val is None or right_val is None:
+ if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif left_val is None or right_val is None:
return None
+
return operator(eval_left(obj), eval_right(obj))
return evaluate
def evaluate(obj):
value = eval_inner(obj)
- if value is None:
+ if value is _EXPIRED_OBJECT:
+ return _EXPIRED_OBJECT
+ elif value is None:
return None
return not value
) -> Optional[_O]:
raise NotImplementedError()
+ def fast_get_state(
+ self, key: _IdentityKeyType[_O]
+ ) -> Optional[InstanceState[_O]]:
+ raise NotImplementedError()
+
def keys(self) -> Iterable[_IdentityKeyType[Any]]:
return self._dict.keys()
self._dict[key] = state
state._instance_dict = self._wr
+ def fast_get_state(
+ self, key: _IdentityKeyType[_O]
+ ) -> Optional[InstanceState[_O]]:
+ return self._dict.get(key)
+
def get(
self, key: _IdentityKeyType[_O], default: Optional[_O] = None
) -> Optional[_O]:
from typing import TypeVar
from typing import Union
-from sqlalchemy.orm.context import FromStatement
from . import attributes
from . import exc as orm_exc
from . import path_registry
from .base import _RAISE_FOR_STATE
from .base import _SET_DEFERRED_EXPIRED
from .base import PassiveFlag
+from .context import FromStatement
from .util import _none_set
from .util import state_str
from .. import exc as sa_exc
from ..sql.selectable import ForUpdateArg
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..sql.selectable import SelectState
+from ..util import EMPTY_DICT
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
)
quick_populators = path.get(
- context.attributes, "memoized_setups", _none_set
+ context.attributes, "memoized_setups", EMPTY_DICT
)
todo = []
_memoized_values: Dict[Any, Callable[[], Any]]
_inheriting_mappers: util.WeakSequence[Mapper[Any]]
_all_tables: Set[Table]
+ _polymorphic_attr_key: Optional[str]
_pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]]
_cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]]
"""
setter = False
+ polymorphic_key: Optional[str] = None
if self.polymorphic_on is not None:
setter = True
self._set_polymorphic_identity = (
mapper._set_polymorphic_identity
)
+ self._polymorphic_attr_key = (
+ mapper._polymorphic_attr_key
+ )
self._validate_polymorphic_identity = (
mapper._validate_polymorphic_identity
)
else:
self._set_polymorphic_identity = None
+ self._polymorphic_attr_key = None
return
if setter:
def _set_polymorphic_identity(state):
dict_ = state.dict
+ # TODO: what happens if polymorphic_on column attribute name
+ # does not match .key?
state.get_impl(polymorphic_key).set(
state,
dict_,
None,
)
+ self._polymorphic_attr_key = polymorphic_key
+
def _validate_polymorphic_identity(mapper, state, dict_):
if (
polymorphic_key in dict_
_validate_polymorphic_identity
)
else:
+ self._polymorphic_attr_key = None
self._set_polymorphic_identity = None
_validate_polymorphic_identity = None
def _compiled_cache(self):
return util.LRUCache(self._compiled_cache_size)
+ @HasMemoized.memoized_attribute
+ def _multiple_persistence_tables(self):
+ return len(self.tables) > 1
+
@HasMemoized.memoized_attribute
def _sorted_tables(self):
table_to_mapper: Dict[Table, Mapper[Any]] = {}
from .. import future
from .. import sql
from .. import util
+from ..engine import cursor as _cursor
from ..sql import operators
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
None
)
+ if bulk and mapper._set_polymorphic_identity:
+ params.setdefault(
+ mapper._polymorphic_attr_key, mapper.polymorphic_identity
+ )
+
yield (
state,
state_dict,
def _collect_update_commands(
- uowtransaction, table, states_to_update, bulk=False
+ uowtransaction,
+ table,
+ states_to_update,
+ bulk=False,
+ use_orm_update_stmt=None,
):
"""Identify sets of values to use in UPDATE statements for a
list of states.
pks = mapper._pks_by_table[table]
- value_params = {}
+ if use_orm_update_stmt is not None:
+ # TODO: ordered values, etc
+ value_params = use_orm_update_stmt._values
+ else:
+ value_params = {}
propkey_to_col = mapper._propkey_to_col[table]
table,
update,
bookkeeping=True,
+ use_orm_update_stmt=None,
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
execution_options = {"compiled_cache": base_mapper._compiled_cache}
- def update_stmt():
+ def update_stmt(existing_stmt=None):
clauses = BooleanClauseList._construct_raw(operators.and_)
for col in mapper._pks_by_table[table]:
)
)
- stmt = table.update().where(clauses)
+ if existing_stmt is not None:
+ stmt = existing_stmt.where(clauses)
+ else:
+ stmt = table.update().where(clauses)
return stmt
- cached_stmt = base_mapper._memo(("update", table), update_stmt)
+ if use_orm_update_stmt is not None:
+ cached_stmt = update_stmt(use_orm_update_stmt)
+
+ else:
+ cached_stmt = base_mapper._memo(("update", table), update_stmt)
for (
(connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
records = list(records)
statement = cached_stmt
+
+ if use_orm_update_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_update_table": table,
+ "_emit_update_mapper": mapper,
+ }
+ )
+
return_defaults = False
if not has_all_pks:
table,
insert,
bookkeeping=True,
+ use_orm_insert_stmt=None,
+ execution_options=None,
):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
- cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ if use_orm_insert_stmt is not None:
+ cached_stmt = use_orm_insert_stmt
+ exec_opt = util.EMPTY_DICT
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
+ # if a user query with RETURNING was passed, we definitely need
+ # to use RETURNING.
+ returning_is_required_anyway = bool(use_orm_insert_stmt._returning)
+ else:
+ returning_is_required_anyway = False
+ cached_stmt = base_mapper._memo(("insert", table), table.insert)
+ exec_opt = {"compiled_cache": base_mapper._compiled_cache}
+
+ if execution_options:
+ execution_options = util.EMPTY_DICT.merge_with(
+ exec_opt, execution_options
+ )
+ else:
+ execution_options = exec_opt
+
+ return_result = None
for (
- (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
+ (connection, _, hasvalue, has_all_pks, has_all_defaults),
records,
) in groupby(
insert,
statement = cached_stmt
+ if use_orm_insert_stmt is not None:
+ statement = statement._annotate(
+ {
+ "_emit_insert_table": table,
+ "_emit_insert_mapper": mapper,
+ }
+ )
+
if (
- not bookkeeping
- or (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not base_mapper.local_table.implicit_returning
- or not connection.dialect.insert_returning
+ (
+ not bookkeeping
+ or (
+ has_all_defaults
+ or not base_mapper.eager_defaults
+ or not base_mapper.local_table.implicit_returning
+ or not connection.dialect.insert_returning
+ )
)
+ and not returning_is_required_anyway
and has_all_pks
and not hasvalue
):
+
# the "we don't need newly generated values back" section.
# here we have all the PKs, all the defaults or we don't want
# to fetch them, or the dialect doesn't support RETURNING at all
records = list(records)
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
if bookkeeping:
has_all_defaults,
),
last_inserted_params,
- ) in zip(records, c.context.compiled_parameters):
+ ) in zip(records, result.context.compiled_parameters):
if state:
_postfetch(
mapper_rec,
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
- c.returned_defaults
- if not c.context.executemany
+ result.returned_defaults
+ if not result.context.executemany
else None,
)
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
- # here, we need defaults and/or pk values back.
+ # here, we need defaults and/or pk values back or we otherwise
+ # know that we are using RETURNING in any case
records = list(records)
if (
and len(records) > 1
):
do_executemany = True
+ elif returning_is_required_anyway:
+ if connection.dialect.insert_executemany_returning:
+ do_executemany = True
+ else:
+ raise sa_exc.InvalidRequestError(
+ f"Can't use explicit RETURNING for bulk INSERT "
+ f"operation with "
+ f"{connection.dialect.dialect_description} backend; "
+ f"executemany is not supported with RETURNING"
+ )
else:
do_executemany = False
statement = statement.return_defaults(
*mapper._server_default_cols[table]
)
+
if mapper.version_id_col is not None:
statement = statement.return_defaults(mapper.version_id_col)
elif do_executemany:
if do_executemany:
multiparams = [rec[2] for rec in records]
- c = connection.execute(
+ result = connection.execute(
statement, multiparams, execution_options=execution_options
)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return_result = result
+ else:
+ return_result = return_result.splice_vertically(result)
+
if bookkeeping:
for (
(
returned_defaults,
) in zip_longest(
records,
- c.context.compiled_parameters,
- c.inserted_primary_key_rows,
- c.returned_defaults_rows or (),
+ result.context.compiled_parameters,
+ result.inserted_primary_key_rows,
+ result.returned_defaults_rows or (),
):
if inserted_primary_key is None:
# this is a real problem and means that we didn't
table,
state,
state_dict,
- c,
+ result,
last_inserted_params,
value_params,
False,
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
else:
+ assert not returning_is_required_anyway
+
for (
state,
state_dict,
else:
_postfetch_bulk_save(mapper_rec, state_dict, table)
+ if use_orm_insert_stmt is not None:
+ if return_result is None:
+ return _cursor.null_dml_result()
+ else:
+ return return_result
+
def _emit_post_update_statements(
base_mapper, uowtransaction, mapper, table, update
)
def delete(
- self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
+ self, synchronize_session: _SynchronizeSessionArgument = "auto"
) -> int:
r"""Perform a DELETE with an arbitrary WHERE clause.
def update(
self,
values: Dict[_DMLColumnArgument, Any],
- synchronize_session: _SynchronizeSessionArgument = "evaluate",
+ synchronize_session: _SynchronizeSessionArgument = "auto",
update_args: Optional[Dict[Any, Any]] = None,
) -> int:
r"""Perform an UPDATE with an arbitrary WHERE clause.
statement._propagate_attrs.get("compile_state_plugin", None)
== "orm"
):
- # note that even without "future" mode, we need
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
if TYPE_CHECKING:
- assert isinstance(compile_state_cls, ORMCompileState)
+ assert isinstance(
+ compile_state_cls, context.AbstractORMCompileState
+ )
else:
compile_state_cls = None
statement, params or {}, execution_options=execution_options
)
- result: Result[Any] = conn.execute(
- statement, params or {}, execution_options=execution_options
- )
-
if compile_state_cls:
- result = compile_state_cls.orm_setup_cursor_result(
+ result: Result[Any] = compile_state_cls.orm_execute_statement(
self,
statement,
- params,
+ params or {},
execution_options,
bind_arguments,
- result,
+ conn,
+ )
+ else:
+ result = conn.execute(
+ statement, params or {}, execution_options=execution_options
)
if _scalar_result:
def scalars(
self,
statement: TypedReturnsRows[Tuple[_T]],
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
def scalars(
self,
statement: Executable,
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
def scalars(
self,
statement: Executable,
- params: Optional[_CoreSingleExecuteParams] = None,
+ params: Optional[_CoreAnyExecuteParams] = None,
*,
execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
fetch = self.columns[0]
if adapter:
fetch = adapter.columns[fetch]
+ if fetch is None:
+ # None happens here only for dml bulk_persistence cases
+ # when context.DMLReturningColFilter is used
+ return
+
memoized_populators[self.parent_property] = fetch
def init_class_attribute(self, mapper):
fetch = columns[0]
if adapter:
fetch = adapter.columns[fetch]
+ if fetch is None:
+ # None is not expected to be the result of any
+ # adapter implementation here, however there may be theoretical
+ # usages of returning() with context.DMLReturningColFilter
+ return
+
memoized_populators[self.parent_property] = fetch
def create_row_processor(
# e.g. BindParameter, add it if present.
if cls.__dict__.get("inherit_cache", False):
anno_cls.inherit_cache = True # type: ignore
+ elif "inherit_cache" in cls.__dict__:
+ anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore
anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators)
delete_stmt, delete_stmt.table, extra_froms
)
+ crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)
+
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
delete_stmt, table_text
text += table_text
- if delete_stmt._returning:
- if self.returning_precedes_values:
- text += " " + self.returning_clause(
- delete_stmt,
- delete_stmt._returning,
- populate_result_map=toplevel,
- )
+ if (
+ self.implicit_returning or delete_stmt._returning
+ ) and self.returning_precedes_values:
+ text += " " + self.returning_clause(
+ delete_stmt,
+ self.implicit_returning or delete_stmt._returning,
+ populate_result_map=toplevel,
+ )
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
if t:
text += " WHERE " + t
- if delete_stmt._returning and not self.returning_precedes_values:
+ if (
+ self.implicit_returning or delete_stmt._returning
+ ) and not self.returning_precedes_values:
text += " " + self.returning_clause(
delete_stmt,
- delete_stmt._returning,
+ self.implicit_returning or delete_stmt._returning,
populate_result_map=toplevel,
)
self._label_select_column(None, c, True, False, {})
for c in base._select_iterables(returning_cols)
]
-
return "RETURNING " + ", ".join(columns)
def update_from_clause(
"return_defaults() simultaneously"
)
+ if compile_state.isdelete:
+ _setup_delete_return_defaults(
+ compiler,
+ stmt,
+ compile_state,
+ (),
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ (),
+ (),
+ toplevel,
+ kw,
+ )
+ return _CrudParams([], [])
+
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
kw,
):
- (
- need_pks,
- implicit_returning,
- implicit_return_defaults,
- postfetch_lastrowid,
- ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
-
cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
assert compiler.stack[-1]["selectable"] is stmt
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel)
+ assert compile_state.isupdate or compile_state.isinsert
+
if compile_state._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in compile_state._parameter_ordering
else:
autoincrement_col = insert_null_pk_still_autoincrements = None
+ if stmt._supplemental_returning:
+ supplemental_returning = set(stmt._supplemental_returning)
+ else:
+ supplemental_returning = set()
+
+ compiler_implicit_returning = compiler.implicit_returning
+
for c in cols:
# scan through every column in the target table
# column has a DDL-level default, and is either not a pk
# column or we don't need the pk.
if implicit_return_defaults and c in implicit_return_defaults:
- compiler.implicit_returning.append(c)
+ compiler_implicit_returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
+
elif implicit_return_defaults and c in implicit_return_defaults:
- compiler.implicit_returning.append(c)
+ compiler_implicit_returning.append(c)
+
elif (
c.primary_key
and c is not stmt.table._autoincrement_column
kw,
)
+ # adding supplemental cols to implicit_returning in table
+ # order so that order is maintained between multiple INSERT
+ # statements which may have different parameters included, but all
+ # have the same RETURNING clause
+ if (
+ c in supplemental_returning
+ and c not in compiler_implicit_returning
+ ):
+ compiler_implicit_returning.append(c)
+
+ if supplemental_returning:
+ # we should have gotten every col into implicit_returning,
+ # however supplemental returning can also have SQL functions etc.
+ # in it
+ remaining_supplemental = supplemental_returning.difference(
+ compiler_implicit_returning
+ )
+ compiler_implicit_returning.extend(
+ c
+ for c in stmt._supplemental_returning
+ if c in remaining_supplemental
+ )
+
+
+def _setup_delete_return_defaults(
+ compiler,
+ stmt,
+ compile_state,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ toplevel,
+ kw,
+):
+ (_, _, implicit_return_defaults, _) = _get_returning_modifiers(
+ compiler, stmt, compile_state, toplevel
+ )
+
+ if not implicit_return_defaults:
+ return
+
+ if stmt._return_defaults_columns:
+ compiler.implicit_returning.extend(implicit_return_defaults)
+
+ if stmt._supplemental_returning:
+ ir_set = set(compiler.implicit_returning)
+ compiler.implicit_returning.extend(
+ c for c in stmt._supplemental_returning if c not in ir_set
+ )
+
def _append_param_parameter(
compiler,
elif compiler.dialect.postfetch_lastrowid:
compiler.postfetch_lastrowid = True
- elif implicit_return_defaults and c in implicit_return_defaults:
+ elif implicit_return_defaults and (c in implicit_return_defaults):
compiler.implicit_returning.append(c)
else:
INSERT or UPDATE statement after it's invoked.
"""
+
need_pks = (
toplevel
and _compile_state_isinsert(compile_state)
)
)
and not stmt._returning
+ # and (not stmt._returning or stmt._return_defaults)
and not compile_state._has_multi_parameters
)
or stmt._return_defaults
)
)
-
if implicit_returning:
postfetch_lastrowid = False
if _compile_state_isinsert(compile_state):
- implicit_return_defaults = implicit_returning and stmt._return_defaults
+ should_implicit_return_defaults = (
+ implicit_returning and stmt._return_defaults
+ )
elif compile_state.isupdate:
- implicit_return_defaults = (
+ should_implicit_return_defaults = (
stmt._return_defaults
and compile_state._primary_table.implicit_returning
and compile_state._supports_implicit_returning
and compiler.dialect.update_returning
)
+ elif compile_state.isdelete:
+ should_implicit_return_defaults = (
+ stmt._return_defaults
+ and compile_state._primary_table.implicit_returning
+ and compile_state._supports_implicit_returning
+ and compiler.dialect.delete_returning
+ )
else:
- # this line is unused, currently we are always
- # isinsert or isupdate
- implicit_return_defaults = False # pragma: no cover
+ should_implicit_return_defaults = False # pragma: no cover
- if implicit_return_defaults:
+ if should_implicit_return_defaults:
if not stmt._return_defaults_columns:
implicit_return_defaults = set(stmt.table.c)
else:
implicit_return_defaults = set(stmt._return_defaults_columns)
+ else:
+ implicit_return_defaults = None
return (
need_pks,
- implicit_returning,
+ implicit_returning or should_implicit_return_defaults,
implicit_return_defaults,
postfetch_lastrowid,
)
def get_plugin_class(cls, statement: Executable) -> Type[DMLState]:
...
+ @classmethod
+ def _get_multi_crud_kv_pairs(
+ cls,
+ statement: UpdateBase,
+ multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]],
+ ) -> List[Dict[_DMLColumnElement, Any]]:
+ return [
+ {
+ coercions.expect(roles.DMLColumnRole, k): v
+ for k, v in mapping.items()
+ }
+ for mapping in multi_kv_iterator
+ ]
+
@classmethod
def _get_crud_kv_pairs(
cls,
statement: UpdateBase,
kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]],
+ needs_to_be_cacheable: bool,
) -> List[Tuple[_DMLColumnElement, Any]]:
return [
(
coercions.expect(roles.DMLColumnRole, k),
- coercions.expect(
+ v
+ if not needs_to_be_cacheable
+ else coercions.expect(
roles.ExpressionElementRole,
v,
type_=NullType(),
def _insert_col_keys(self) -> List[str]:
# this is also done in crud.py -> _key_getters_for_crud_column
return [
- coercions.expect_as_key(roles.DMLColumnRole, col)
+ coercions.expect(roles.DMLColumnRole, col, as_key=True)
for col in self._dict_parameters or ()
]
self._extra_froms = ef
self.is_multitable = mt = ef
-
self.include_table_with_column_exprs = bool(
mt and compiler.render_table_with_column_in_update_from
)
_return_defaults_columns: Optional[
Tuple[_ColumnsClauseElement, ...]
] = None
+ _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None
_returning: Tuple[_ColumnsClauseElement, ...] = ()
is_dml = True
self._validate_dialect_kwargs(opt)
return self
+ @_generative
+ def return_defaults(
+ self: SelfUpdateBase,
+ *cols: _DMLColumnArgument,
+ supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None,
+ ) -> SelfUpdateBase:
+ """Make use of a :term:`RETURNING` clause for the purpose
+ of fetching server-side expressions and defaults, for supporting
+ backends only.
+
+ .. deepalchemy::
+
+ The :meth:`.UpdateBase.return_defaults` method is used by the ORM
+ for its internal work in fetching newly generated primary key
+ and server default values, in particular to provide the underyling
+ implementation of the :paramref:`_orm.Mapper.eager_defaults`
+ ORM feature as well as to allow RETURNING support with bulk
+ ORM inserts. Its behavior is fairly idiosyncratic
+ and is not really intended for general use. End users should
+ stick with using :meth:`.UpdateBase.returning` in order to
+ add RETURNING clauses to their INSERT, UPDATE and DELETE
+ statements.
+
+ Normally, a single row INSERT statement will automatically populate the
+ :attr:`.CursorResult.inserted_primary_key` attribute when executed,
+ which stores the primary key of the row that was just inserted in the
+ form of a :class:`.Row` object with column names as named tuple keys
+ (and the :attr:`.Row._mapping` view fully populated as well). The
+ dialect in use chooses the strategy to use in order to populate this
+ data; if it was generated using server-side defaults and / or SQL
+ expressions, dialect-specific approaches such as ``cursor.lastrowid``
+ or ``RETURNING`` are typically used to acquire the new primary key
+ value.
+
+ However, when the statement is modified by calling
+ :meth:`.UpdateBase.return_defaults` before executing the statement,
+ additional behaviors take place **only** for backends that support
+ RETURNING and for :class:`.Table` objects that maintain the
+ :paramref:`.Table.implicit_returning` parameter at its default value of
+ ``True``. In these cases, when the :class:`.CursorResult` is returned
+ from the statement's execution, not only will
+ :attr:`.CursorResult.inserted_primary_key` be populated as always, the
+ :attr:`.CursorResult.returned_defaults` attribute will also be
+ populated with a :class:`.Row` named-tuple representing the full range
+ of server generated
+ values from that single row, including values for any columns that
+ specify :paramref:`_schema.Column.server_default` or which make use of
+ :paramref:`_schema.Column.default` using a SQL expression.
+
+ When invoking INSERT statements with multiple rows using
+ :ref:`insertmanyvalues <engine_insertmanyvalues>`, the
+ :meth:`.UpdateBase.return_defaults` modifier will have the effect of
+ the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
+ :attr:`_engine.CursorResult.returned_defaults_rows` attributes being
+ fully populated with lists of :class:`.Row` objects representing newly
+ inserted primary key values as well as newly inserted server generated
+ values for each row inserted. The
+ :attr:`.CursorResult.inserted_primary_key` and
+ :attr:`.CursorResult.returned_defaults` attributes will also continue
+ to be populated with the first row of these two collections.
+
+ If the backend does not support RETURNING or the :class:`.Table` in use
+ has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
+ clause is added and no additional data is fetched, however the
+ INSERT, UPDATE or DELETE statement proceeds normally.
+
+ E.g.::
+
+ stmt = table.insert().values(data='newdata').return_defaults()
+
+ result = connection.execute(stmt)
+
+ server_created_at = result.returned_defaults['created_at']
+
+ When used against an UPDATE statement
+ :meth:`.UpdateBase.return_defaults` instead looks for columns that
+ include :paramref:`_schema.Column.onupdate` or
+ :paramref:`_schema.Column.server_onupdate` parameters assigned, when
+ constructing the columns that will be included in the RETURNING clause
+ by default if explicit columns were not specified. When used against a
+ DELETE statement, no columns are included in RETURNING by default, they
+ instead must be specified explicitly as there are no columns that
+ normally change values when a DELETE statement proceeds.
+
+ .. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported
+ for DELETE statements also and has been moved from
+ :class:`.ValuesBase` to :class:`.UpdateBase`.
+
+ The :meth:`.UpdateBase.return_defaults` method is mutually exclusive
+ against the :meth:`.UpdateBase.returning` method and errors will be
+ raised during the SQL compilation process if both are used at the same
+ time on one statement. The RETURNING clause of the INSERT, UPDATE or
+ DELETE statement is therefore controlled by only one of these methods
+ at a time.
+
+ The :meth:`.UpdateBase.return_defaults` method differs from
+ :meth:`.UpdateBase.returning` in these ways:
+
+ 1. :meth:`.UpdateBase.return_defaults` method causes the
+ :attr:`.CursorResult.returned_defaults` collection to be populated
+ with the first row from the RETURNING result. This attribute is not
+ populated when using :meth:`.UpdateBase.returning`.
+
+ 2. :meth:`.UpdateBase.return_defaults` is compatible with existing
+ logic used to fetch auto-generated primary key values that are then
+ populated into the :attr:`.CursorResult.inserted_primary_key`
+ attribute. By contrast, using :meth:`.UpdateBase.returning` will
+ have the effect of the :attr:`.CursorResult.inserted_primary_key`
+ attribute being left unpopulated.
+
+ 3. :meth:`.UpdateBase.return_defaults` can be called against any
+ backend. Backends that don't support RETURNING will skip the usage
+ of the feature, rather than raising an exception. The return value
+ of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
+ for backends that don't support RETURNING or for which the target
+ :class:`.Table` sets :paramref:`.Table.implicit_returning` to
+ ``False``.
+
+ 4. An INSERT statement invoked with executemany() is supported if the
+ backend database driver supports the
+ :ref:`insertmanyvalues <engine_insertmanyvalues>`
+ feature which is now supported by most SQLAlchemy-included backends.
+ When executemany is used, the
+ :attr:`_engine.CursorResult.returned_defaults_rows` and
+ :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
+ will return the inserted defaults and primary keys.
+
+ .. versionadded:: 1.4 Added
+ :attr:`_engine.CursorResult.returned_defaults_rows` and
+ :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
+ In version 2.0, the underlying implementation which fetches and
+ populates the data for these attributes was generalized to be
+ supported by most backends, whereas in 1.4 they were only
+ supported by the ``psycopg2`` driver.
+
+
+ :param cols: optional list of column key names or
+ :class:`_schema.Column` that acts as a filter for those columns that
+ will be fetched.
+ :param supplemental_cols: optional list of RETURNING expressions,
+ in the same form as one would pass to the
+ :meth:`.UpdateBase.returning` method. When present, the additional
+ columns will be included in the RETURNING clause, and the
+ :class:`.CursorResult` object will be "rewound" when returned, so
+ that methods like :meth:`.CursorResult.all` will return new rows
+ mostly as though the statement used :meth:`.UpdateBase.returning`
+ directly. However, unlike when using :meth:`.UpdateBase.returning`
+ directly, the **order of the columns is undefined**, so can only be
+ targeted using names or :attr:`.Row._mapping` keys; they cannot
+ reliably be targeted positionally.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.UpdateBase.returning`
+
+ :attr:`_engine.CursorResult.returned_defaults`
+
+ :attr:`_engine.CursorResult.returned_defaults_rows`
+
+ :attr:`_engine.CursorResult.inserted_primary_key`
+
+ :attr:`_engine.CursorResult.inserted_primary_key_rows`
+
+ """
+
+ if self._return_defaults:
+ # note _return_defaults_columns = () means return all columns,
+ # so if we have been here before, only update collection if there
+ # are columns in the collection
+ if self._return_defaults_columns and cols:
+ self._return_defaults_columns = tuple(
+ util.OrderedSet(self._return_defaults_columns).union(
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in cols
+ )
+ )
+ else:
+ # set for all columns
+ self._return_defaults_columns = ()
+ else:
+ self._return_defaults_columns = tuple(
+ coercions.expect(roles.ColumnsClauseRole, c) for c in cols
+ )
+ self._return_defaults = True
+
+ if supplemental_cols:
+ # uniquifying while also maintaining order (the maintain of order
+ # is for test suites but also for vertical splicing
+ supplemental_col_tup = (
+ coercions.expect(roles.ColumnsClauseRole, c)
+ for c in supplemental_cols
+ )
+
+ if self._supplemental_returning is None:
+ self._supplemental_returning = tuple(
+ util.unique_list(supplemental_col_tup)
+ )
+ else:
+ self._supplemental_returning = tuple(
+ util.unique_list(
+ self._supplemental_returning
+ + tuple(supplemental_col_tup)
+ )
+ )
+
+ return self
+
@_generative
def returning(
self, *cols: _ColumnsClauseArgument[Any], **__kw: Any
.. seealso::
- :meth:`.ValuesBase.return_defaults` - an alternative method tailored
+ :meth:`.UpdateBase.return_defaults` - an alternative method tailored
towards efficient fetching of server-side defaults and triggers
for single-row INSERTs or UPDATEs.
_select_names: Optional[List[str]] = None
_inline: bool = False
- _returning: Tuple[_ColumnsClauseElement, ...] = ()
def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
)
elif isinstance(arg, collections_abc.Sequence):
- if arg and isinstance(arg[0], (list, dict, tuple)):
+
+ if arg and isinstance(arg[0], dict):
+ multi_kv_generator = DMLState.get_plugin_class(
+ self
+ )._get_multi_crud_kv_pairs
+ self._multi_values += (multi_kv_generator(self, arg),)
+ return self
+
+ if arg and isinstance(arg[0], (list, tuple)):
self._multi_values += (arg,)
return self
# and ensures they get the "crud"-style name when rendered.
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
- coerced_arg = {k: v for k, v in kv_generator(self, arg.items())}
+ coerced_arg = dict(kv_generator(self, arg.items(), True))
if self._values:
self._values = self._values.union(coerced_arg)
else:
self._values = util.immutabledict(coerced_arg)
return self
- @_generative
- def return_defaults(
- self: SelfValuesBase, *cols: _DMLColumnArgument
- ) -> SelfValuesBase:
- """Make use of a :term:`RETURNING` clause for the purpose
- of fetching server-side expressions and defaults, for supporting
- backends only.
-
- .. tip::
-
- The :meth:`.ValuesBase.return_defaults` method is used by the ORM
- for its internal work in fetching newly generated primary key
- and server default values, in particular to provide the underyling
- implementation of the :paramref:`_orm.Mapper.eager_defaults`
- ORM feature. Its behavior is fairly idiosyncratic
- and is not really intended for general use. End users should
- stick with using :meth:`.UpdateBase.returning` in order to
- add RETURNING clauses to their INSERT, UPDATE and DELETE
- statements.
-
- Normally, a single row INSERT statement will automatically populate the
- :attr:`.CursorResult.inserted_primary_key` attribute when executed,
- which stores the primary key of the row that was just inserted in the
- form of a :class:`.Row` object with column names as named tuple keys
- (and the :attr:`.Row._mapping` view fully populated as well). The
- dialect in use chooses the strategy to use in order to populate this
- data; if it was generated using server-side defaults and / or SQL
- expressions, dialect-specific approaches such as ``cursor.lastrowid``
- or ``RETURNING`` are typically used to acquire the new primary key
- value.
-
- However, when the statement is modified by calling
- :meth:`.ValuesBase.return_defaults` before executing the statement,
- additional behaviors take place **only** for backends that support
- RETURNING and for :class:`.Table` objects that maintain the
- :paramref:`.Table.implicit_returning` parameter at its default value of
- ``True``. In these cases, when the :class:`.CursorResult` is returned
- from the statement's execution, not only will
- :attr:`.CursorResult.inserted_primary_key` be populated as always, the
- :attr:`.CursorResult.returned_defaults` attribute will also be
- populated with a :class:`.Row` named-tuple representing the full range
- of server generated
- values from that single row, including values for any columns that
- specify :paramref:`_schema.Column.server_default` or which make use of
- :paramref:`_schema.Column.default` using a SQL expression.
-
- When invoking INSERT statements with multiple rows using
- :ref:`insertmanyvalues <engine_insertmanyvalues>`, the
- :meth:`.ValuesBase.return_defaults` modifier will have the effect of
- the :attr:`_engine.CursorResult.inserted_primary_key_rows` and
- :attr:`_engine.CursorResult.returned_defaults_rows` attributes being
- fully populated with lists of :class:`.Row` objects representing newly
- inserted primary key values as well as newly inserted server generated
- values for each row inserted. The
- :attr:`.CursorResult.inserted_primary_key` and
- :attr:`.CursorResult.returned_defaults` attributes will also continue
- to be populated with the first row of these two collections.
-
- If the backend does not support RETURNING or the :class:`.Table` in use
- has disabled :paramref:`.Table.implicit_returning`, then no RETURNING
- clause is added and no additional data is fetched, however the
- INSERT or UPDATE statement proceeds normally.
-
-
- E.g.::
-
- stmt = table.insert().values(data='newdata').return_defaults()
-
- result = connection.execute(stmt)
-
- server_created_at = result.returned_defaults['created_at']
-
-
- The :meth:`.ValuesBase.return_defaults` method is mutually exclusive
- against the :meth:`.UpdateBase.returning` method and errors will be
- raised during the SQL compilation process if both are used at the same
- time on one statement. The RETURNING clause of the INSERT or UPDATE
- statement is therefore controlled by only one of these methods at a
- time.
-
- The :meth:`.ValuesBase.return_defaults` method differs from
- :meth:`.UpdateBase.returning` in these ways:
-
- 1. :meth:`.ValuesBase.return_defaults` method causes the
- :attr:`.CursorResult.returned_defaults` collection to be populated
- with the first row from the RETURNING result. This attribute is not
- populated when using :meth:`.UpdateBase.returning`.
-
- 2. :meth:`.ValuesBase.return_defaults` is compatible with existing
- logic used to fetch auto-generated primary key values that are then
- populated into the :attr:`.CursorResult.inserted_primary_key`
- attribute. By contrast, using :meth:`.UpdateBase.returning` will
- have the effect of the :attr:`.CursorResult.inserted_primary_key`
- attribute being left unpopulated.
-
- 3. :meth:`.ValuesBase.return_defaults` can be called against any
- backend. Backends that don't support RETURNING will skip the usage
- of the feature, rather than raising an exception. The return value
- of :attr:`_engine.CursorResult.returned_defaults` will be ``None``
- for backends that don't support RETURNING or for which the target
- :class:`.Table` sets :paramref:`.Table.implicit_returning` to
- ``False``.
-
- 4. An INSERT statement invoked with executemany() is supported if the
- backend database driver supports the
- :ref:`insertmanyvalues <engine_insertmanyvalues>`
- feature which is now supported by most SQLAlchemy-included backends.
- When executemany is used, the
- :attr:`_engine.CursorResult.returned_defaults_rows` and
- :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors
- will return the inserted defaults and primary keys.
-
- .. versionadded:: 1.4 Added
- :attr:`_engine.CursorResult.returned_defaults_rows` and
- :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors.
- In version 2.0, the underlying implementation which fetches and
- populates the data for these attributes was generalized to be
- supported by most backends, whereas in 1.4 they were only
- supported by the ``psycopg2`` driver.
-
-
- :param cols: optional list of column key names or
- :class:`_schema.Column` that acts as a filter for those columns that
- will be fetched.
-
- .. seealso::
-
- :meth:`.UpdateBase.returning`
-
- :attr:`_engine.CursorResult.returned_defaults`
-
- :attr:`_engine.CursorResult.returned_defaults_rows`
-
- :attr:`_engine.CursorResult.inserted_primary_key`
-
- :attr:`_engine.CursorResult.inserted_primary_key_rows`
-
- """
-
- if self._return_defaults:
- # note _return_defaults_columns = () means return all columns,
- # so if we have been here before, only update collection if there
- # are columns in the collection
- if self._return_defaults_columns and cols:
- self._return_defaults_columns = tuple(
- set(self._return_defaults_columns).union(
- coercions.expect(roles.ColumnsClauseRole, c)
- for c in cols
- )
- )
- else:
- # set for all columns
- self._return_defaults_columns = ()
- else:
- self._return_defaults_columns = tuple(
- coercions.expect(roles.ColumnsClauseRole, c) for c in cols
- )
- self._return_defaults = True
- return self
-
SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
)
kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
- self._ordered_values = kv_generator(self, args)
+ self._ordered_values = kv_generator(self, args, True)
return self
@_generative
class CompiledSQL(SQLMatchRule):
def __init__(
- self, statement, params=None, dialect="default", enable_returning=False
+ self, statement, params=None, dialect="default", enable_returning=True
):
self.statement = statement
self.params = params
dialect.insert_returning = (
dialect.update_returning
) = dialect.delete_returning = True
+ dialect.use_insertmanyvalues = True
+ dialect.supports_multivalues_insert = True
+ dialect.update_returning_multifrom = True
+ dialect.delete_returning_multifrom = True
+ # dialect.favor_returning_over_lastrowid = True
+ # dialect.insert_null_pk_still_autoincrements = True
+
+ # this is calculated but we need it to be True for this
+ # to look like all the current RETURNING dialects
+ assert dialect.insert_executemany_returning
+
return dialect
else:
return url.URL.create(self.dialect).get_dialect()()
from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
-from ..orm import declarative_base
from ..orm import DeclarativeBase
from ..orm import MappedAsDataclass
from ..orm import registry
metadata=metadata,
type_annotation_map={
str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb"
+ sa.String(50), "mysql", "mariadb", "oracle"
)
},
)
metadata = _md
type_annotation_map = {
str: sa.String().with_variant(
- sa.String(50), "mysql", "mariadb"
+ sa.String(50), "mysql", "mariadb", "oracle"
)
}
def _with_register_classes(cls, fn):
cls_registry = cls.classes
- class DeclarativeBasic:
+ class _DeclBase(DeclarativeBase):
__table_cls__ = schema.Table
+ metadata = cls._tables_metadata
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb", "oracle"
+ )
+ }
- def __init_subclass__(cls) -> None:
+ def __init_subclass__(cls, **kw) -> None:
assert cls_registry is not None
cls_registry[cls.__name__] = cls
- super().__init_subclass__()
-
- _DeclBase = declarative_base(
- metadata=cls._tables_metadata,
- cls=DeclarativeBasic,
- )
+ super().__init_subclass__(**kw)
cls.DeclarativeBasic = _DeclBase
eq_(r.rowcount, 3)
@testing.requires.update_returning
- @testing.requires.sane_rowcount_w_returning
def test_update_rowcount_return_defaults(self, connection):
+ """note this test should succeed for all RETURNING backends
+ as of 2.0. In
+ Idf28379f8705e403a3c6a937f6a798a042ef2540 we changed rowcount to use
+ len(rows) when we have implicit returning
+
+ """
employees_table = self.tables.employees
department = employees_table.c.department
from itertools import filterfalse
from typing import AbstractSet
from typing import Any
+from typing import Callable
from typing import cast
from typing import Collection
from typing import Dict
return "%s(%r)" % (type(self).__name__, list(self._members.values()))
-def unique_list(seq, hashfunc=None):
+def unique_list(
+ seq: Iterable[_T], hashfunc: Optional[Callable[[_T], int]] = None
+) -> List[_T]:
seen: Set[Any] = set()
seen_add = seen.add
if not hashfunc:
t = get_tokyo(sess2)
eq_(t.city, tokyo.city)
- def test_bulk_update_synchronize_evaluate(self):
+ @testing.combinations(
+ "fetch", "evaluate", "auto", argnames="synchronize_session"
+ )
+ @testing.combinations(True, False, argnames="legacy")
+ def test_orm_update_synchronize(self, synchronize_session, legacy):
sess = self._fixture_data()
eq_(
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
- sess.query(Report).filter(Report.temperature >= 80).update(
- {"temperature": Report.temperature + 6},
- synchronize_session="evaluate",
- )
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {86.0, 75.0, 91.0},
- )
-
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
- def test_bulk_update_synchronize_fetch(self):
- sess = self._fixture_data()
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {80.0, 75.0, 85.0},
- )
+ if legacy:
+ sess.query(Report).filter(Report.temperature >= 80).update(
+ {"temperature": Report.temperature + 6},
+ synchronize_session=synchronize_session,
+ )
+ else:
+ sess.execute(
+ update(Report)
+ .filter(Report.temperature >= 80)
+ .values(temperature=Report.temperature + 6)
+ .execution_options(synchronize_session=synchronize_session)
+ )
- temps = sess.query(Report).all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+ # test synchronize session
+ def go():
+ eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
- sess.query(Report).filter(Report.temperature >= 80).update(
- {"temperature": Report.temperature + 6},
- synchronize_session="fetch",
+ self.assert_sql_count(
+ sess._ShardedSession__binds["north_america"], go, 0
)
eq_(
{86.0, 75.0, 91.0},
)
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
- def test_bulk_delete_synchronize_evaluate(self):
- sess = self._fixture_data()
-
- temps = sess.query(Report).all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- sess.query(Report).filter(Report.temperature >= 80).delete(
- synchronize_session="evaluate"
- )
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {75.0},
- )
-
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
- def test_bulk_delete_synchronize_fetch(self):
+ @testing.combinations(
+ "fetch", "evaluate", "auto", argnames="synchronize_session"
+ )
+ @testing.combinations(True, False, argnames="legacy")
+ def test_orm_delete_synchronize(self, synchronize_session, legacy):
sess = self._fixture_data()
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
- sess.query(Report).filter(Report.temperature >= 80).delete(
- synchronize_session="fetch"
- )
-
- eq_(
- set(row.temperature for row in sess.query(Report.temperature)),
- {75.0},
- )
-
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
- def test_bulk_update_future_synchronize_evaluate(self):
- sess = self._fixture_data()
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {80.0, 75.0, 85.0},
- )
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- sess.execute(
- update(Report)
- .filter(Report.temperature >= 80)
- .values(
- {"temperature": Report.temperature + 6},
+ if legacy:
+ sess.query(Report).filter(Report.temperature >= 80).delete(
+ synchronize_session=synchronize_session
)
- .execution_options(synchronize_session="evaluate")
- )
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {86.0, 75.0, 91.0},
- )
-
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
-
- def test_bulk_update_future_synchronize_fetch(self):
- sess = self._fixture_data()
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {80.0, 75.0, 85.0},
- )
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- # MARKMARK
- # omitting the criteria so that the UPDATE affects three out of
- # four shards
- sess.execute(
- update(Report)
- .values(
- {"temperature": Report.temperature + 6},
+ else:
+ sess.execute(
+ delete(Report)
+ .filter(Report.temperature >= 80)
+ .execution_options(synchronize_session=synchronize_session)
)
- .execution_options(synchronize_session="fetch")
- )
-
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {86.0, 81.0, 91.0},
- )
-
- # test synchronize session as well
- eq_(set(t.temperature for t in temps), {86.0, 81.0, 91.0})
-
- def test_bulk_delete_future_synchronize_evaluate(self):
- sess = self._fixture_data()
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
-
- sess.execute(
- delete(Report)
- .filter(Report.temperature >= 80)
- .execution_options(synchronize_session="evaluate")
- )
- eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
- {75.0},
- )
-
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
- def test_bulk_delete_future_synchronize_fetch(self):
- sess = self._fixture_data()
-
- temps = sess.execute(select(Report)).scalars().all()
- eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+ def go():
+ # test synchronize session
+ for t in temps:
+ assert inspect(t).deleted is (t.temperature >= 80)
- sess.execute(
- delete(Report)
- .filter(Report.temperature >= 80)
- .execution_options(synchronize_session="fetch")
+ self.assert_sql_count(
+ sess._ShardedSession__binds["north_america"], go, 0
)
eq_(
- set(
- row.temperature
- for row in sess.execute(select(Report.temperature))
- ),
+ set(row.temperature for row in sess.query(Report.temperature)),
{75.0},
)
- # test synchronize session as well
- for t in temps:
- assert inspect(t).deleted is (t.temperature >= 80)
-
class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
def _init_dbs(self):
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
+from sqlalchemy import insert
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
params={"first_name": "Dr."},
)
- def test_update_expr(self):
+ @testing.combinations("attr", "str", "kwarg", argnames="keytype")
+ def test_update_expr(self, keytype):
Person = self.classes.Person
- statement = update(Person).values({Person.name: "Dr. No"})
+ if keytype == "attr":
+ statement = update(Person).values({Person.name: "Dr. No"})
+ elif keytype == "str":
+ statement = update(Person).values({"name": "Dr. No"})
+ elif keytype == "kwarg":
+ statement = update(Person).values(name="Dr. No")
+ else:
+ assert False
self.assert_compile(
statement,
"UPDATE person SET first_name=:first_name, last_name=:last_name",
- params={"first_name": "Dr.", "last_name": "No"},
+ checkparams={"first_name": "Dr.", "last_name": "No"},
+ )
+
+ @testing.combinations("attr", "str", "kwarg", argnames="keytype")
+ def test_insert_expr(self, keytype):
+ Person = self.classes.Person
+
+ if keytype == "attr":
+ statement = insert(Person).values({Person.name: "Dr. No"})
+ elif keytype == "str":
+ statement = insert(Person).values({"name": "Dr. No"})
+ elif keytype == "kwarg":
+ statement = insert(Person).values(name="Dr. No")
+ else:
+ assert False
+
+ self.assert_compile(
+ statement,
+ "INSERT INTO person (first_name, last_name) VALUES "
+ "(:first_name, :last_name)",
+ checkparams={"first_name": "Dr.", "last_name": "No"},
)
# these tests all run two UPDATES to assert that caching is not
from sqlalchemy import FetchedValue
from sqlalchemy import ForeignKey
+from sqlalchemy import Identity
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import update
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):
+ __backend__ = True
+
@classmethod
def setup_mappers(cls):
User, Address, Order = cls.classes("User", "Address", "Order")
cls.mapper_registry.map_imperatively(Address, a)
cls.mapper_registry.map_imperatively(Order, o)
- def test_bulk_save_return_defaults(self):
+ @testing.combinations("save_objects", "insert_mappings", "insert_stmt")
+ def test_bulk_save_return_defaults(self, statement_type):
(User,) = self.classes("User")
s = fixture_session()
- objects = [User(name="u1"), User(name="u2"), User(name="u3")]
- assert "id" not in objects[0].__dict__
- with self.sql_execution_asserter() as asserter:
- s.bulk_save_objects(objects, return_defaults=True)
+ if statement_type == "save_objects":
+ objects = [User(name="u1"), User(name="u2"), User(name="u3")]
+ assert "id" not in objects[0].__dict__
+
+ returning_users_id = " RETURNING users.id"
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_save_objects(objects, return_defaults=True)
+ elif statement_type == "insert_mappings":
+ data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
+ returning_users_id = " RETURNING users.id"
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_insert_mappings(User, data, return_defaults=True)
+ elif statement_type == "insert_stmt":
+ data = [dict(name="u1"), dict(name="u2"), dict(name="u3")]
+
+ # for statement, "return_defaults" is heuristic on if we are
+ # a joined inh mapping if we don't otherwise include
+ # .returning() on the statement itself
+ returning_users_id = ""
+ with self.sql_execution_asserter() as asserter:
+ s.execute(insert(User), data)
asserter.assert_(
Conditional(
- testing.db.dialect.insert_executemany_returning,
+ testing.db.dialect.insert_executemany_returning
+ or statement_type == "insert_stmt",
[
CompiledSQL(
- "INSERT INTO users (name) VALUES (:name)",
+ "INSERT INTO users (name) "
+ f"VALUES (:name){returning_users_id}",
[{"name": "u1"}, {"name": "u2"}, {"name": "u3"}],
),
],
],
)
)
- eq_(objects[0].__dict__["id"], 1)
+ if statement_type == "save_objects":
+ eq_(objects[0].__dict__["id"], 1)
def test_bulk_save_mappings_preserve_order(self):
(User,) = self.classes("User")
)
)
- def test_bulk_update(self):
- (User,) = self.classes("User")
+ @testing.combinations("update_mappings", "update_stmt")
+ def test_bulk_update(self, statement_type):
+ User = self.classes.User
s = fixture_session(expire_on_commit=False)
objects = [User(name="u1"), User(name="u2"), User(name="u3")]
s.commit()
s = fixture_session()
- with self.sql_execution_asserter() as asserter:
- s.bulk_update_mappings(
- User,
- [
- {"id": 1, "name": "u1new"},
- {"id": 2, "name": "u2"},
- {"id": 3, "name": "u3new"},
- ],
- )
+ data = [
+ {"id": 1, "name": "u1new"},
+ {"id": 2, "name": "u2"},
+ {"id": 3, "name": "u3new"},
+ ]
+
+ if statement_type == "update_mappings":
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_update_mappings(User, data)
+ elif statement_type == "update_stmt":
+ with self.sql_execution_asserter() as asserter:
+ s.execute(update(User), data)
asserter.assert_(
CompiledSQL(
class BulkUDPostfetchTest(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
class BulkUDTestAltColKeys(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
class BulkInheritanceTest(BulkTest, fixtures.MappedTest):
+ __backend__ = True
+
@classmethod
def define_tables(cls, metadata):
Table(
)
s = fixture_session()
+
objects = [
Manager(name="m1", status="s1", manager_name="mn1"),
Engineer(name="e1", status="s2", primary_language="l1"),
[
CompiledSQL(
"INSERT INTO people (name, type) "
- "VALUES (:name, :type)",
+ "VALUES (:name, :type) RETURNING people.person_id",
[
{"type": "engineer", "name": "e1"},
{"type": "engineer", "name": "e2"},
),
)
- def test_bulk_insert_joined_inh_return_defaults(self):
+ @testing.combinations("insert_mappings", "insert_stmt")
+ def test_bulk_insert_joined_inh_return_defaults(self, statement_type):
Person, Engineer, Manager, Boss = self.classes(
"Person", "Engineer", "Manager", "Boss"
)
s = fixture_session()
- with self.sql_execution_asserter() as asserter:
- s.bulk_insert_mappings(
- Boss,
- [
- dict(
- name="b1",
- status="s1",
- manager_name="mn1",
- golf_swing="g1",
- ),
- dict(
- name="b2",
- status="s2",
- manager_name="mn2",
- golf_swing="g2",
- ),
- dict(
- name="b3",
- status="s3",
- manager_name="mn3",
- golf_swing="g3",
- ),
- ],
- return_defaults=True,
- )
+ data = [
+ dict(
+ name="b1",
+ status="s1",
+ manager_name="mn1",
+ golf_swing="g1",
+ ),
+ dict(
+ name="b2",
+ status="s2",
+ manager_name="mn2",
+ golf_swing="g2",
+ ),
+ dict(
+ name="b3",
+ status="s3",
+ manager_name="mn3",
+ golf_swing="g3",
+ ),
+ ]
+
+ if statement_type == "insert_mappings":
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_insert_mappings(
+ Boss,
+ data,
+ return_defaults=True,
+ )
+ elif statement_type == "insert_stmt":
+ with self.sql_execution_asserter() as asserter:
+ s.execute(insert(Boss), data)
asserter.assert_(
Conditional(
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b1"}, {"name": "b2"}, {"name": "b3"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type) RETURNING people.person_id",
+ [
+ {"name": "b1", "type": "boss"},
+ {"name": "b2", "type": "boss"},
+ {"name": "b3", "type": "boss"},
+ ],
),
],
[
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b1"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type)",
+ [{"name": "b1", "type": "boss"}],
),
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b2"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type)",
+ [{"name": "b2", "type": "boss"}],
),
CompiledSQL(
- "INSERT INTO people (name) VALUES (:name)",
- [{"name": "b3"}],
+ "INSERT INTO people (name, type) "
+ "VALUES (:name, :type)",
+ [{"name": "b3", "type": "boss"}],
),
],
),
),
)
+ @testing.combinations("update_mappings", "update_stmt")
+ def test_bulk_update(self, statement_type):
+ Person, Engineer, Manager, Boss = self.classes(
+ "Person", "Engineer", "Manager", "Boss"
+ )
+
+ s = fixture_session()
+
+ b1, b2, b3 = (
+ Boss(name="b1", status="s1", manager_name="mn1", golf_swing="g1"),
+ Boss(name="b2", status="s2", manager_name="mn2", golf_swing="g2"),
+ Boss(name="b3", status="s3", manager_name="mn3", golf_swing="g3"),
+ )
+ s.add_all([b1, b2, b3])
+ s.commit()
+
+ # slight non-convenient thing. we have to fill in boss_id here
+ # for update, this is not sent along automatically. this is not a
+ # new behavior in bulk
+ new_data = [
+ {
+ "person_id": b1.person_id,
+ "boss_id": b1.boss_id,
+ "name": "b1_updated",
+ "manager_name": "mn1_updated",
+ },
+ {
+ "person_id": b3.person_id,
+ "boss_id": b3.boss_id,
+ "manager_name": "mn2_updated",
+ "golf_swing": "g1_updated",
+ },
+ ]
+
+ if statement_type == "update_mappings":
+ with self.sql_execution_asserter() as asserter:
+ s.bulk_update_mappings(Boss, new_data)
+ elif statement_type == "update_stmt":
+ with self.sql_execution_asserter() as asserter:
+ s.execute(update(Boss), new_data)
+
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE people SET name=:name WHERE "
+ "people.person_id = :people_person_id",
+ [{"name": "b1_updated", "people_person_id": 1}],
+ ),
+ CompiledSQL(
+ "UPDATE managers SET manager_name=:manager_name WHERE "
+ "managers.person_id = :managers_person_id",
+ [
+ {"manager_name": "mn1_updated", "managers_person_id": 1},
+ {"manager_name": "mn2_updated", "managers_person_id": 3},
+ ],
+ ),
+ CompiledSQL(
+ "UPDATE boss SET golf_swing=:golf_swing WHERE "
+ "boss.boss_id = :boss_boss_id",
+ [{"golf_swing": "g1_updated", "boss_boss_id": 3}],
+ ),
+ )
+
class BulkIssue6793Test(BulkTest, fixtures.DeclarativeMappedTest):
+ __backend__ = True
+
@classmethod
def setup_classes(cls):
Base = cls.DeclarativeBasic
class User(Base):
__tablename__ = "users"
- id = Column(Integer, primary_key=True)
+ id = Column(Integer, Identity(), primary_key=True)
name = Column(String(255), nullable=False)
def test_issue_6793(self):
[{"name": "A"}, {"name": "B"}],
),
CompiledSQL(
- "INSERT INTO users (name) VALUES (:name)",
+ "INSERT INTO users (name) VALUES (:name) "
+ "RETURNING users.id",
[{"name": "C"}, {"name": "D"}],
),
],
--- /dev/null
+from __future__ import annotations
+
+from typing import Any
+from typing import List
+from typing import Optional
+import uuid
+
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import func
+from sqlalchemy import Identity
+from sqlalchemy import insert
+from sqlalchemy import inspect
+from sqlalchemy import literal_column
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.testing import config
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import mock
+from sqlalchemy.testing import provision
+from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import fixture_session
+
+
+class NoReturningTest(fixtures.TestBase):
+ def test_no_returning_error(self, decl_base):
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+
+ decl_base.metadata.create_all(testing.db)
+ s = fixture_session()
+
+ if testing.requires.insert_executemany_returning.enabled:
+ result = s.scalars(
+ insert(A).returning(A),
+ [
+ {"data": "d3", "x": 5},
+ {"data": "d4", "x": 6},
+ ],
+ )
+ eq_(result.all(), [A(data="d3", x=5), A(data="d4", x=6)])
+
+ else:
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "Can't use explicit RETURNING for bulk INSERT operation",
+ ):
+ s.scalars(
+ insert(A).returning(A),
+ [
+ {"data": "d3", "x": 5},
+ {"data": "d4", "x": 6},
+ ],
+ )
+
+ def test_omit_returning_ok(self, decl_base):
+ class A(decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+
+ decl_base.metadata.create_all(testing.db)
+ s = fixture_session()
+
+ s.execute(
+ insert(A),
+ [
+ {"data": "d3", "x": 5},
+ {"data": "d4", "x": 6},
+ ],
+ )
+ eq_(
+ s.execute(select(A.data, A.x).order_by(A.id)).all(),
+ [("d3", 5), ("d4", 6)],
+ )
+
+
+class BulkDMLReturningInhTest:
+ def test_insert_col_key_also_works_currently(self):
+ """using the column key, not mapped attr key.
+
+ right now this passes through to the INSERT. when doing this with
+ an UPDATE, it tends to fail because the synchronize session
+ strategies can't match "xcol" back. however w/ INSERT we aren't
+ doing that, so there's no place this gets checked. UPDATE also
+ succeeds if synchronize_session is turned off.
+
+ """
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+ s.execute(insert(A).values(type="a", data="d", xcol=10))
+ eq_(s.scalars(select(A.x)).all(), [10])
+
+ @testing.combinations(True, False, argnames="use_returning")
+ def test_heterogeneous_keys(self, use_returning):
+ A, B = self.classes("A", "B")
+
+ values = [
+ {"data": "d3", "x": 5, "type": "a"},
+ {"data": "d4", "x": 6, "type": "a"},
+ {"data": "d5", "type": "a"},
+ {"data": "d6", "x": 8, "y": 9, "type": "a"},
+ {"data": "d7", "x": 12, "y": 12, "type": "a"},
+ {"data": "d8", "x": 7, "type": "a"},
+ ]
+
+ s = fixture_session()
+
+ stmt = insert(A)
+ if use_returning:
+ stmt = stmt.returning(A)
+
+ with self.sql_execution_asserter() as asserter:
+ result = s.execute(stmt, values)
+
+ if inspect(B).single:
+ single_inh = ", a.bd, a.zcol, a.q"
+ else:
+ single_inh = ""
+
+ if use_returning:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) VALUES "
+ "(:type, :data, :xcol) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [
+ {"type": "a", "data": "d3", "xcol": 5},
+ {"type": "a", "data": "d4", "xcol": 6},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data) VALUES (:type, :data) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [{"type": "a", "data": "d5"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y) "
+ "VALUES (:type, :data, :xcol, :y) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [
+ {"type": "a", "data": "d6", "xcol": 8, "y": 9},
+ {"type": "a", "data": "d7", "xcol": 12, "y": 12},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) "
+ "VALUES (:type, :data, :xcol) "
+ f"RETURNING a.id, a.type, a.data, a.xcol, a.y{single_inh}",
+ [{"type": "a", "data": "d8", "xcol": 7}],
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) VALUES "
+ "(:type, :data, :xcol)",
+ [
+ {"type": "a", "data": "d3", "xcol": 5},
+ {"type": "a", "data": "d4", "xcol": 6},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data) VALUES (:type, :data)",
+ [{"type": "a", "data": "d5"}],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y) "
+ "VALUES (:type, :data, :xcol, :y)",
+ [
+ {"type": "a", "data": "d6", "xcol": 8, "y": 9},
+ {"type": "a", "data": "d7", "xcol": 12, "y": 12},
+ ],
+ ),
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol) "
+ "VALUES (:type, :data, :xcol)",
+ [{"type": "a", "data": "d8", "xcol": 7}],
+ ),
+ )
+
+ if use_returning:
+ eq_(
+ result.scalars().all(),
+ [
+ A(data="d3", id=mock.ANY, type="a", x=5, y=None),
+ A(data="d4", id=mock.ANY, type="a", x=6, y=None),
+ A(data="d5", id=mock.ANY, type="a", x=None, y=None),
+ A(data="d6", id=mock.ANY, type="a", x=8, y=9),
+ A(data="d7", id=mock.ANY, type="a", x=12, y=12),
+ A(data="d8", id=mock.ANY, type="a", x=7, y=None),
+ ],
+ )
+
+ @testing.combinations(
+ "strings",
+ "cols",
+ "strings_w_exprs",
+ "cols_w_exprs",
+ argnames="paramstyle",
+ )
+ @testing.combinations(
+ True,
+ (False, testing.requires.multivalues_inserts),
+ argnames="single_element",
+ )
+ def test_single_values_returning_fn(self, paramstyle, single_element):
+ """test using insert().values().
+
+ these INSERT statements go straight in as a single execute without any
+ insertmanyreturning or bulk_insert_mappings thing going on. the
+ advantage here is that SQL expressions can be used in the values also.
+ Disadvantage is none of the automation for inheritance mappers.
+
+ """
+ A, B = self.classes("A", "B")
+
+ if paramstyle == "strings":
+ values = [
+ {"data": "d3", "x": 5, "y": 9, "type": "a"},
+ {"data": "d4", "x": 10, "y": 8, "type": "a"},
+ ]
+ elif paramstyle == "cols":
+ values = [
+ {A.data: "d3", A.x: 5, A.y: 9, A.type: "a"},
+ {A.data: "d4", A.x: 10, A.y: 8, A.type: "a"},
+ ]
+ elif paramstyle == "strings_w_exprs":
+ values = [
+ {"data": func.lower("D3"), "x": 5, "y": 9, "type": "a"},
+ {
+ "data": "d4",
+ "x": literal_column("5") + 5,
+ "y": 8,
+ "type": "a",
+ },
+ ]
+ elif paramstyle == "cols_w_exprs":
+ values = [
+ {A.data: func.lower("D3"), A.x: 5, A.y: 9, A.type: "a"},
+ {
+ A.data: "d4",
+ A.x: literal_column("5") + 5,
+ A.y: 8,
+ A.type: "a",
+ },
+ ]
+ else:
+ assert False
+
+ s = fixture_session()
+
+ if single_element:
+ if paramstyle.startswith("strings"):
+ stmt = (
+ insert(A)
+ .values(**values[0])
+ .returning(A, func.upper(A.data, type_=String))
+ )
+ else:
+ stmt = (
+ insert(A)
+ .values(values[0])
+ .returning(A, func.upper(A.data, type_=String))
+ )
+ else:
+ stmt = (
+ insert(A)
+ .values(values)
+ .returning(A, func.upper(A.data, type_=String))
+ )
+
+ for i in range(3):
+ result = s.execute(stmt)
+ expected: List[Any] = [(A(data="d3", x=5, y=9), "D3")]
+ if not single_element:
+ expected.append((A(data="d4", x=10, y=8), "D4"))
+ eq_(result.all(), expected)
+
+ def test_bulk_w_sql_expressions(self):
+ A, B = self.classes("A", "B")
+
+ data = [
+ {"x": 5, "y": 9, "type": "a"},
+ {
+ "x": 10,
+ "y": 8,
+ "type": "a",
+ },
+ ]
+
+ s = fixture_session()
+
+ stmt = (
+ insert(A)
+ .values(data=func.lower("DD"))
+ .returning(A, func.upper(A.data, type_=String))
+ )
+
+ for i in range(3):
+ result = s.execute(stmt, data)
+ expected: List[Any] = [
+ (A(data="dd", x=5, y=9), "DD"),
+ (A(data="dd", x=10, y=8), "DD"),
+ ]
+ eq_(result.all(), expected)
+
+ def test_bulk_w_sql_expressions_subclass(self):
+ A, B = self.classes("A", "B")
+
+ data = [
+ {"bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ s = fixture_session()
+
+ stmt = (
+ insert(B)
+ .values(data=func.lower("DD"))
+ .returning(B, func.upper(B.data, type_=String))
+ )
+
+ for i in range(3):
+ result = s.execute(stmt, data)
+ expected: List[Any] = [
+ (B(bd="bd1", data="dd", q=4, type="b", x=1, y=2, z=3), "DD"),
+ (B(bd="bd2", data="dd", q=8, type="b", x=5, y=6, z=7), "DD"),
+ ]
+ eq_(result.all(), expected)
+
+ @testing.combinations(True, False, argnames="use_ordered")
+ def test_bulk_upd_w_sql_expressions_no_ordered_values(self, use_ordered):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ stmt = update(B).ordered_values(
+ ("data", func.lower("DD_UPDATE")),
+ ("z", literal_column("3 + 12")),
+ )
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"bulk ORM UPDATE does not support ordered_values\(\) "
+ r"for custom UPDATE",
+ ):
+ s.execute(
+ stmt,
+ [
+ {"id": 5, "bd": "bd1_updated"},
+ {"id": 6, "bd": "bd2_updated"},
+ ],
+ )
+
+ def test_bulk_upd_w_sql_expressions_subclass(self):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ ids = s.scalars(insert(B).returning(B.id), data).all()
+
+ stmt = update(B).values(
+ data=func.lower("DD_UPDATE"), z=literal_column("3 + 12")
+ )
+
+ result = s.execute(
+ stmt,
+ [
+ {"id": ids[0], "bd": "bd1_updated"},
+ {"id": ids[1], "bd": "bd2_updated"},
+ ],
+ )
+
+ # this is a nullresult at the moment
+ assert result is not None
+
+ eq_(
+ s.scalars(select(B)).all(),
+ [
+ B(
+ bd="bd1_updated",
+ data="dd_update",
+ id=ids[0],
+ q=4,
+ type="b",
+ x=1,
+ y=2,
+ z=15,
+ ),
+ B(
+ bd="bd2_updated",
+ data="dd_update",
+ id=ids[1],
+ q=8,
+ type="b",
+ x=5,
+ y=6,
+ z=15,
+ ),
+ ],
+ )
+
+ def test_single_returning_fn(self):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+ for i in range(3):
+ result = s.execute(
+ insert(A).returning(A, func.upper(A.data, type_=String)),
+ [{"data": "d3"}, {"data": "d4"}],
+ )
+ eq_(result.all(), [(A(data="d3"), "D3"), (A(data="d4"), "D4")])
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_no_returning(self, single_element):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ result = s.execute(insert(B), data)
+ assert result._soft_closed
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_load_only(self, single_element):
+ """test that load_only() prevents additional attributes from being
+ populated.
+
+ """
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ for i in range(3):
+ # tests both caching and that the data dictionaries aren't
+ # mutated...
+ result = s.execute(
+ insert(B).returning(B).options(load_only(B.data, B.y, B.q)),
+ data,
+ )
+ objects = result.scalars().all()
+ for obj in objects:
+ assert "data" in obj.__dict__
+ assert "q" in obj.__dict__
+ assert "z" not in obj.__dict__
+ assert "x" not in obj.__dict__
+
+ expected = [
+ B(data="d3", bd="bd1", x=1, y=2, z=3, q=4),
+ ]
+ if not single_element:
+ expected.append(B(data="d4", bd="bd2", x=5, y=6, z=7, q=8))
+ eq_(objects, expected)
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_load_only_doesnt_fetch_cols(self, single_element):
+ """test that when using load_only(), the actual INSERT statement
+ does not include the deferred columns
+
+ """
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ if single_element:
+ data = data[0]
+
+ with self.sql_execution_asserter() as asserter:
+
+ # tests both caching and that the data dictionaries aren't
+ # mutated...
+
+ # note that if we don't put B.id here, accessing .id on the
+ # B object for joined inheritance is triggering a SELECT
+ # (and not for single inheritance). this seems not great, but is
+ # likely a different issue
+ result = s.execute(
+ insert(B)
+ .returning(B)
+ .options(load_only(B.id, B.data, B.y, B.q)),
+ data,
+ )
+ objects = result.scalars().all()
+ if single_element:
+ id0 = objects[0].id
+ id1 = None
+ else:
+ id0, id1 = objects[0].id, objects[1].id
+
+ if inspect(B).single or inspect(B).concrete:
+ expected_params = [
+ {
+ "type": "b",
+ "data": "d3",
+ "xcol": 1,
+ "y": 2,
+ "bd": "bd1",
+ "zcol": 3,
+ "q": 4,
+ },
+ {
+ "type": "b",
+ "data": "d4",
+ "xcol": 5,
+ "y": 6,
+ "bd": "bd2",
+ "zcol": 7,
+ "q": 8,
+ },
+ ]
+ if single_element:
+ expected_params[1:] = []
+ # RETURNING only includes PK, discriminator, then the cols
+ # we asked for data, y, q. xcol, z, bd are omitted
+
+ if inspect(B).single:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y, bd, zcol, q) "
+ "VALUES "
+ "(:type, :data, :xcol, :y, :bd, :zcol, :q) "
+ "RETURNING a.id, a.type, a.data, a.y, a.q",
+ expected_params,
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO b (type, data, xcol, y, bd, zcol, q) "
+ "VALUES "
+ "(:type, :data, :xcol, :y, :bd, :zcol, :q) "
+ "RETURNING b.id, b.type, b.data, b.y, b.q",
+ expected_params,
+ ),
+ )
+ else:
+ a_data = [
+ {"type": "b", "data": "d3", "xcol": 1, "y": 2},
+ {"type": "b", "data": "d4", "xcol": 5, "y": 6},
+ ]
+ b_data = [
+ {"id": id0, "bd": "bd1", "zcol": 3, "q": 4},
+ {"id": id1, "bd": "bd2", "zcol": 7, "q": 8},
+ ]
+ if single_element:
+ a_data[1:] = []
+ b_data[1:] = []
+ # RETURNING only includes PK, discriminator, then the cols
+ # we asked for data, y, q. xcol, z, bd are omitted. plus they
+ # are broken out correctly in the two statements.
+ asserter.assert_(
+ CompiledSQL(
+ "INSERT INTO a (type, data, xcol, y) VALUES "
+ "(:type, :data, :xcol, :y) "
+ "RETURNING a.id, a.type, a.data, a.y",
+ a_data,
+ ),
+ CompiledSQL(
+ "INSERT INTO b (id, bd, zcol, q) "
+ "VALUES (:id, :bd, :zcol, :q) "
+ "RETURNING b.id, b.q",
+ b_data,
+ ),
+ )
+
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_element",
+ )
+ def test_subclass_returning_bind_expr(self, single_element):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ # note there's a fix in compiler.py ->
+ # _deliver_insertmanyvalues_batches
+ # for this re: the parameter rendering that isn't tested anywhere
+ # else. two different versions of the bug for both positional
+ # and non
+ result = s.execute(insert(B).returning(B.data, B.y, B.q + 5), data)
+ if single_element:
+ eq_(result.all(), [("d3", 2, 9)])
+ else:
+ eq_(result.all(), [("d3", 2, 9), ("d4", 6, 13)])
+
+ def test_subclass_bulk_update(self):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ ids = s.scalars(insert(B).returning(B.id), data).all()
+
+ result = s.execute(
+ update(B),
+ [
+ {"id": ids[0], "data": "d3_updated", "bd": "bd1_updated"},
+ {"id": ids[1], "data": "d4_updated", "bd": "bd2_updated"},
+ ],
+ )
+
+ # this is a nullresult at the moment
+ assert result is not None
+
+ eq_(
+ s.scalars(select(B)).all(),
+ [
+ B(
+ bd="bd1_updated",
+ data="d3_updated",
+ id=ids[0],
+ q=4,
+ type="b",
+ x=1,
+ y=2,
+ z=3,
+ ),
+ B(
+ bd="bd2_updated",
+ data="d4_updated",
+ id=ids[1],
+ q=8,
+ type="b",
+ x=5,
+ y=6,
+ z=7,
+ ),
+ ],
+ )
+
+ @testing.combinations(True, False, argnames="single_element")
+ def test_subclass_return_just_subclass_ids(self, single_element):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ if single_element:
+ data = {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4}
+ else:
+ data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+
+ ids = s.scalars(insert(B).returning(B.id), data).all()
+ actual_ids = s.scalars(select(B.id).order_by(B.data)).all()
+
+ eq_(ids, actual_ids)
+
+ @testing.combinations(
+ "orm",
+ "bulk",
+ argnames="insert_strategy",
+ )
+ @testing.requires.provisioned_upsert
+ def test_base_class_upsert(self, insert_strategy):
+ """upsert is really tricky. if you dont have any data updated,
+ then you dont get the rows back and things dont work so well.
+
+ so we need to be careful how much we document this because this is
+ still a thorny use case.
+
+ """
+ A = self.classes.A
+
+ s = fixture_session()
+
+ initial_data = [
+ {"data": "d3", "x": 1, "y": 2, "q": 4},
+ {"data": "d4", "x": 5, "y": 6, "q": 8},
+ ]
+ ids = s.scalars(insert(A).returning(A.id), initial_data).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "type": "a",
+ "data": "d3",
+ "x": 1,
+ "y": 2,
+ },
+ {
+ "id": 32,
+ "type": "a",
+ "data": "d32",
+ "x": 19,
+ "y": 5,
+ },
+ {
+ "id": ids[1],
+ "type": "a",
+ "data": "d4",
+ "x": 5,
+ "y": 6,
+ },
+ {
+ "id": 28,
+ "type": "a",
+ "data": "d28",
+ "x": 9,
+ "y": 15,
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ A,
+ (A,),
+ lambda inserted: {"data": inserted.data + " upserted"},
+ )
+
+ if insert_strategy == "orm":
+ result = s.scalars(stmt.values(upsert_data))
+ elif insert_strategy == "bulk":
+ result = s.scalars(stmt, upsert_data)
+ else:
+ assert False
+
+ eq_(
+ result.all(),
+ [
+ A(data="d3 upserted", id=ids[0], type="a", x=1, y=2),
+ A(data="d32", id=32, type="a", x=19, y=5),
+ A(data="d4 upserted", id=ids[1], type="a", x=5, y=6),
+ A(data="d28", id=28, type="a", x=9, y=15),
+ ],
+ )
+
+ @testing.combinations(
+ "orm",
+ "bulk",
+ argnames="insert_strategy",
+ )
+ @testing.requires.provisioned_upsert
+ def test_subclass_upsert(self, insert_strategy):
+ """note this is overridden in the joined version to expect failure"""
+
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ idd3 = 1
+ idd4 = 2
+ id32 = 32
+ id28 = 28
+
+ initial_data = [
+ {
+ "id": idd3,
+ "data": "d3",
+ "bd": "bd1",
+ "x": 1,
+ "y": 2,
+ "z": 3,
+ "q": 4,
+ },
+ {
+ "id": idd4,
+ "data": "d4",
+ "bd": "bd2",
+ "x": 5,
+ "y": 6,
+ "z": 7,
+ "q": 8,
+ },
+ ]
+ ids = s.scalars(insert(B).returning(B.id), initial_data).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "type": "b",
+ "data": "d3",
+ "bd": "bd1_upserted",
+ "x": 1,
+ "y": 2,
+ "z": 33,
+ "q": 44,
+ },
+ {
+ "id": id32,
+ "type": "b",
+ "data": "d32",
+ "bd": "bd 32",
+ "x": 19,
+ "y": 5,
+ "z": 20,
+ "q": 21,
+ },
+ {
+ "id": ids[1],
+ "type": "b",
+ "bd": "bd2_upserted",
+ "data": "d4",
+ "x": 5,
+ "y": 6,
+ "z": 77,
+ "q": 88,
+ },
+ {
+ "id": id28,
+ "type": "b",
+ "data": "d28",
+ "bd": "bd 28",
+ "x": 9,
+ "y": 15,
+ "z": 10,
+ "q": 11,
+ },
+ ]
+
+ stmt = provision.upsert(
+ config,
+ B,
+ (B,),
+ lambda inserted: {
+ "data": inserted.data + " upserted",
+ "bd": inserted.bd + " upserted",
+ },
+ )
+ result = s.scalars(stmt, upsert_data)
+ eq_(
+ result.all(),
+ [
+ B(
+ bd="bd1_upserted upserted",
+ data="d3 upserted",
+ id=ids[0],
+ q=4,
+ type="b",
+ x=1,
+ y=2,
+ z=3,
+ ),
+ B(
+ bd="bd 32",
+ data="d32",
+ id=32,
+ q=21,
+ type="b",
+ x=19,
+ y=5,
+ z=20,
+ ),
+ B(
+ bd="bd2_upserted upserted",
+ data="d4 upserted",
+ id=ids[1],
+ q=8,
+ type="b",
+ x=5,
+ y=6,
+ z=7,
+ ),
+ B(
+ bd="bd 28",
+ data="d28",
+ id=28,
+ q=11,
+ type="b",
+ x=9,
+ y=15,
+ z=10,
+ ),
+ ],
+ )
+
+
+class BulkDMLReturningJoinedInhTest(
+ BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+
+ __requires__ = ("insert_returning",)
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "a",
+ "polymorphic_on": "type",
+ }
+
+ class B(A):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(
+ ForeignKey("a.id"), primary_key=True
+ )
+ bd: Mapped[str]
+ z: Mapped[Optional[int]] = mapped_column("zcol")
+ q: Mapped[Optional[int]]
+
+ __mapper_args__ = {"polymorphic_identity": "b"}
+
+ @testing.combinations(
+ "orm",
+ "bulk",
+ argnames="insert_strategy",
+ )
+ @testing.combinations(
+ True,
+ False,
+ argnames="single_param",
+ )
+ @testing.requires.provisioned_upsert
+ def test_subclass_upsert(self, insert_strategy, single_param):
+ A, B = self.classes("A", "B")
+
+ s = fixture_session()
+
+ initial_data = [
+ {"data": "d3", "bd": "bd1", "x": 1, "y": 2, "z": 3, "q": 4},
+ {"data": "d4", "bd": "bd2", "x": 5, "y": 6, "z": 7, "q": 8},
+ ]
+ ids = s.scalars(insert(B).returning(B.id), initial_data).all()
+
+ upsert_data = [
+ {
+ "id": ids[0],
+ "type": "b",
+ },
+ {
+ "id": 32,
+ "type": "b",
+ },
+ ]
+ if single_param:
+ upsert_data = upsert_data[0]
+
+ stmt = provision.upsert(
+ config,
+ B,
+ (B,),
+ lambda inserted: {
+ "bd": inserted.bd + " upserted",
+ },
+ )
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ r"bulk INSERT with a 'post values' clause \(typically upsert\) "
+ r"not supported for multi-table mapper",
+ ):
+ s.scalars(stmt, upsert_data)
+
+
+class BulkDMLReturningSingleInhTest(
+ BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+ __requires__ = ("insert_returning",)
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "a",
+ "polymorphic_on": "type",
+ }
+
+ class B(A):
+ bd: Mapped[str] = mapped_column(nullable=True)
+ z: Mapped[Optional[int]] = mapped_column("zcol")
+ q: Mapped[Optional[int]]
+
+ __mapper_args__ = {"polymorphic_identity": "b"}
+
+
+class BulkDMLReturningConcreteInhTest(
+ BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
+):
+ __requires__ = ("insert_returning",)
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class A(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "a",
+ "polymorphic_on": "type",
+ }
+
+ class B(A):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ type: Mapped[str]
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column("xcol")
+ y: Mapped[Optional[int]]
+
+ bd: Mapped[str] = mapped_column(nullable=True)
+ z: Mapped[Optional[int]] = mapped_column("zcol")
+ q: Mapped[Optional[int]]
+
+ __mapper_args__ = {
+ "polymorphic_identity": "b",
+ "concrete": True,
+ "polymorphic_on": "type",
+ }
+
+
+class CTETest(fixtures.DeclarativeMappedTest):
+ __requires__ = ("insert_returning", "ctes_on_dml")
+ __backend__ = True
+
+ @classmethod
+ def setup_classes(cls):
+ decl_base = cls.DeclarativeBasic
+
+ class User(fixtures.ComparableEntity, decl_base):
+ __tablename__ = "users"
+ id: Mapped[uuid.UUID] = mapped_column(primary_key=True)
+ username: Mapped[str]
+
+ @testing.combinations(
+ ("cte_aliased", True),
+ ("cte", False),
+ argnames="wrap_cte_in_aliased",
+ id_="ia",
+ )
+ @testing.combinations(
+ ("use_union", True),
+ ("no_union", False),
+ argnames="use_a_union",
+ id_="ia",
+ )
+ @testing.combinations(
+ "from_statement", "aliased", "direct", argnames="fetch_entity_type"
+ )
+ def test_select_from_insert_cte(
+ self, wrap_cte_in_aliased, use_a_union, fetch_entity_type
+ ):
+ """test the use case from #8544; SELECT that selects from a
+ CTE INSERT...RETURNING.
+
+ """
+ User = self.classes.User
+
+ id_ = uuid.uuid4()
+
+ cte = (
+ insert(User)
+ .values(id=id_, username="some user")
+ .returning(User)
+ .cte()
+ )
+ if wrap_cte_in_aliased:
+ cte = aliased(User, cte)
+
+ if use_a_union:
+ stmt = select(User).where(User.id == id_).union(select(cte))
+ else:
+ stmt = select(cte)
+
+ if fetch_entity_type == "from_statement":
+ outer_stmt = select(User).from_statement(stmt)
+ expect_entity = True
+ elif fetch_entity_type == "aliased":
+ outer_stmt = select(aliased(User, stmt.subquery()))
+ expect_entity = True
+ elif fetch_entity_type == "direct":
+ outer_stmt = stmt
+ expect_entity = not use_a_union and wrap_cte_in_aliased
+ else:
+ assert False
+
+ sess = fixture_session()
+ with self.sql_execution_asserter() as asserter:
+
+ if not expect_entity:
+ row = sess.execute(outer_stmt).one()
+ eq_(row, (id_, "some user"))
+ else:
+ new_user = sess.scalars(outer_stmt).one()
+ eq_(new_user, User(id=id_, username="some user"))
+
+ cte_sql = (
+ "(INSERT INTO users (id, username) "
+ "VALUES (:param_1, :param_2) "
+ "RETURNING users.id, users.username)"
+ )
+
+ if fetch_entity_type == "aliased" and not use_a_union:
+ expected = (
+ f"WITH anon_2 AS {cte_sql} "
+ "SELECT anon_1.id, anon_1.username "
+ "FROM (SELECT anon_2.id AS id, anon_2.username AS username "
+ "FROM anon_2) AS anon_1"
+ )
+ elif not use_a_union:
+ expected = (
+ f"WITH anon_1 AS {cte_sql} "
+ "SELECT anon_1.id, anon_1.username FROM anon_1"
+ )
+ elif fetch_entity_type == "aliased":
+ expected = (
+ f"WITH anon_2 AS {cte_sql} SELECT anon_1.id, anon_1.username "
+ "FROM (SELECT users.id AS id, users.username AS username "
+ "FROM users WHERE users.id = :id_1 "
+ "UNION SELECT anon_2.id AS id, anon_2.username AS username "
+ "FROM anon_2) AS anon_1"
+ )
+ else:
+ expected = (
+ f"WITH anon_1 AS {cte_sql} "
+ "SELECT users.id, users.username FROM users "
+ "WHERE users.id = :id_1 "
+ "UNION SELECT anon_1.id, anon_1.username FROM anon_1"
+ )
+
+ asserter.assert_(
+ CompiledSQL(expected, [{"param_1": id_, "param_2": "some user"}])
+ )
"""test #3162"""
User = self.classes.User
-
with expect_raises_message(
evaluator.UnevaluatableError,
r"Custom operator '\^\^' can't be evaluated in "
+from sqlalchemy import bindparam
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import insert
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import lambda_stmt
from sqlalchemy import MetaData
from sqlalchemy import text
from sqlalchemy import update
from sqlalchemy.orm import backref
+from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import in_
from sqlalchemy.testing import not_in
},
)
+ def test_update_dont_use_col_key(self):
+ User = self.classes.User
+
+ s = fixture_session()
+
+ # make sure objects are present to synchronize
+ _ = s.query(User).all()
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "Attribute name not found, can't be synchronized back "
+ "to objects: 'age_int'",
+ ):
+ s.execute(update(User).values(age_int=5))
+
+ stmt = update(User).values(age=5)
+ s.execute(stmt)
+ eq_(s.scalars(select(User.age)).all(), [5, 5, 5, 5])
+
@testing.combinations("table", "mapper", "both", argnames="bind_type")
@testing.combinations(
"update", "insert", "delete", argnames="statement_type"
assert_raises_message(
exc.ArgumentError,
"Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False",
+ "are 'auto', 'evaluate', 'fetch', False",
s.query(User).update,
{},
synchronize_session="fake",
def test_evaluate_dont_refresh_expired_objects(
self, expire_jane_age, add_filter_criteria
):
+ """test #5664.
+
+ approach is revised in SQLAlchemy 2.0 to not pre-emptively
+ unexpire the involved attributes
+
+ """
User = self.classes.User
sess = fixture_session()
if add_filter_criteria:
if expire_jane_age:
asserter.assert_(
- # it has to unexpire jane.name, because jane is not fully
- # expired and the criteria needs to look at this particular
- # key
- CompiledSQL(
- "SELECT users.age_int AS users_age_int, "
- "users.name AS users_name FROM users "
- "WHERE users.id = :pk_1",
- [{"pk_1": 4}],
- ),
+ # previously, this would unexpire the attribute and
+ # cause an additional SELECT. The
+ # 2.0 approach is that if the object has expired attrs
+ # we just expire the whole thing, avoiding SQL up front
CompiledSQL(
"UPDATE users "
"SET age_int=(users.age_int + :age_int_1) "
)
else:
asserter.assert_(
- # it has to unexpire jane.name, because jane is not fully
- # expired and the criteria needs to look at this particular
- # key
- CompiledSQL(
- "SELECT users.name AS users_name FROM users "
- "WHERE users.id = :pk_1",
- [{"pk_1": 4}],
- ),
+ # previously, this would unexpire the attribute and
+ # cause an additional SELECT. The
+ # 2.0 approach is that if the object has expired attrs
+ # we just expire the whole thing, avoiding SQL up front
CompiledSQL(
"UPDATE users SET "
"age_int=(users.age_int + :age_int_1) "
),
]
- if expire_jane_age and not add_filter_criteria:
+ if expire_jane_age:
to_assert.append(
- # refresh jane
+ # refresh jane for partial attributes
CompiledSQL(
"SELECT users.age_int AS users_age_int, "
"users.name AS users_name FROM users "
)
asserter.assert_(*to_assert)
+ @testing.combinations(True, False, argnames="is_evaluable")
+ def test_auto_synchronize(self, is_evaluable):
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+ if is_evaluable:
+ crit = or_(User.name == "jack", User.name == "jane")
+ else:
+ crit = case((User.name.in_(["jack", "jane"]), True), else_=False)
+
+ with self.sql_execution_asserter() as asserter:
+ sess.execute(update(User).where(crit).values(age=User.age + 10))
+
+ if is_evaluable:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+ "WHERE users.name = :name_1 OR users.name = :name_2",
+ [{"age_int_1": 10, "name_1": "jack", "name_2": "jane"}],
+ ),
+ )
+ elif testing.db.dialect.update_returning:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+ "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
+ "THEN :param_1 ELSE :param_2 END = 1 RETURNING users.id",
+ [
+ {
+ "age_int_1": 10,
+ "name_1": ["jack", "jane"],
+ "param_1": True,
+ "param_2": False,
+ }
+ ],
+ ),
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "SELECT users.id FROM users WHERE CASE WHEN "
+ "(users.name IN (__[POSTCOMPILE_name_1])) "
+ "THEN :param_1 ELSE :param_2 END = 1",
+ [
+ {
+ "name_1": ["jack", "jane"],
+ "param_1": True,
+ "param_2": False,
+ }
+ ],
+ ),
+ CompiledSQL(
+ "UPDATE users SET age_int=(users.age_int + :age_int_1) "
+ "WHERE CASE WHEN (users.name IN (__[POSTCOMPILE_name_1])) "
+ "THEN :param_1 ELSE :param_2 END = 1",
+ [
+ {
+ "age_int_1": 10,
+ "name_1": ["jack", "jane"],
+ "param_1": True,
+ "param_2": False,
+ }
+ ],
+ ),
+ )
+
def test_fetch_dont_refresh_expired_objects(self):
User = self.classes.User
),
)
- def test_delete(self):
+ @testing.combinations(False, None, "auto", "evaluate", "fetch")
+ def test_delete(self, synchronize_session):
User = self.classes.User
sess = fixture_session()
john, jack, jill, jane = sess.query(User).order_by(User.id).all()
- sess.query(User).filter(
+
+ stmt = delete(User).filter(
or_(User.name == "john", User.name == "jill")
- ).delete()
+ )
+ if synchronize_session is not None:
+ stmt = stmt.execution_options(
+ synchronize_session=synchronize_session
+ )
+ sess.execute(stmt)
- assert john not in sess and jill not in sess
+ if synchronize_session not in (False, None):
+ assert john not in sess and jill not in sess
eq_(sess.query(User).order_by(User.id).all(), [jack, jane])
eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])
+ def test_update_multirow_not_supported(self):
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "WHERE clause with bulk ORM UPDATE not supported " "right now.",
+ ):
+ sess.execute(
+ update(User).where(User.id == bindparam("id")),
+ [{"id": 1, "age": 27}, {"id": 2, "age": 37}],
+ )
+
+ def test_delete_bulk_not_supported(self):
+ User = self.classes.User
+
+ sess = fixture_session()
+
+ with expect_raises_message(
+ exc.InvalidRequestError, "Bulk ORM DELETE not supported right now."
+ ):
+ sess.execute(
+ delete(User),
+ [{"id": 1}, {"id": 2}],
+ )
+
def test_update(self):
User, users = self.classes.User, self.tables.users
)
eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+
eq_(
sess.query(User.age).order_by(User.id).all(),
list(zip([25, 37, 29, 27])),
)
@testing.requires.update_returning
- def test_update_explicit_returning(self):
+ def test_update_evaluate_w_explicit_returning(self):
User = self.classes.User
sess = fixture_session()
.filter(User.age > 29)
.values({"age": User.age - 10})
.returning(User.id)
+ .execution_options(synchronize_session="evaluate")
)
rows = sess.execute(stmt).all()
)
@testing.requires.update_returning
- def test_no_fetch_w_explicit_returning(self):
+ @testing.combinations("update", "delete", argnames="crud_type")
+ def test_fetch_w_explicit_returning(self, crud_type):
User = self.classes.User
sess = fixture_session()
- stmt = (
- update(User)
- .filter(User.age > 29)
- .values({"age": User.age - 10})
- .execution_options(synchronize_session="fetch")
- .returning(User.id)
- )
- with expect_raises_message(
- exc.InvalidRequestError,
- r"Can't use synchronize_session='fetch' "
- r"with explicit returning\(\)",
- ):
- sess.execute(stmt)
+ if crud_type == "update":
+ stmt = (
+ update(User)
+ .filter(User.age > 29)
+ .values({"age": User.age - 10})
+ .execution_options(synchronize_session="fetch")
+ .returning(User, User.name)
+ )
+ expected = [
+ (User(age=37), "jack"),
+ (User(age=27), "jane"),
+ ]
+ elif crud_type == "delete":
+ stmt = (
+ delete(User)
+ .filter(User.age > 29)
+ .execution_options(synchronize_session="fetch")
+ .returning(User, User.name)
+ )
+ expected = [
+ (User(age=47), "jack"),
+ (User(age=37), "jane"),
+ ]
+ else:
+ assert False
+
+ result = sess.execute(stmt)
+
+ eq_(result.all(), expected)
@testing.combinations(True, False, argnames="implicit_returning")
def test_delete_fetch_returning(self, implicit_returning):
list(zip([25, 47, 44, 37])),
)
- def test_update_changes_resets_dirty(self):
+ @testing.combinations("orm", "bulk")
+ def test_update_changes_resets_dirty(self, update_type):
User = self.classes.User
sess = fixture_session(autoflush=False)
# autoflush is false. therefore our '50' and '37' are getting
# blown away by this operation.
- sess.query(User).filter(User.age > 29).update(
- {"age": User.age - 10}, synchronize_session="evaluate"
- )
+ if update_type == "orm":
+ sess.execute(
+ update(User)
+ .filter(User.age > 29)
+ .values({"age": User.age - 10}),
+ execution_options=dict(synchronize_session="evaluate"),
+ )
+ elif update_type == "bulk":
+
+ data = [
+ {"id": john.id, "age": 25},
+ {"id": jack.id, "age": 37},
+ {"id": jill.id, "age": 29},
+ {"id": jane.id, "age": 27},
+ ]
+
+ sess.execute(
+ update(User),
+ data,
+ execution_options=dict(synchronize_session="evaluate"),
+ )
+
+ else:
+ assert False
for x in (john, jack, jill, jane):
assert not sess.is_modified(x)
assert not sess.is_modified(john)
assert not sess.is_modified(jack)
+ @testing.combinations(
+ None, False, "evaluate", "fetch", argnames="synchronize_session"
+ )
+ @testing.combinations(True, False, argnames="homogeneous_keys")
+ def test_bulk_update_synchronize_session(
+ self, synchronize_session, homogeneous_keys
+ ):
+ User = self.classes.User
+
+ sess = fixture_session(expire_on_commit=False)
+
+ john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+ if homogeneous_keys:
+ data = [
+ {"id": john.id, "age": 35},
+ {"id": jack.id, "age": 27},
+ {"id": jill.id, "age": 30},
+ ]
+ else:
+ data = [
+ {"id": john.id, "age": 35},
+ {"id": jack.id, "name": "new jack"},
+ {"id": jill.id, "age": 30, "name": "new jill"},
+ ]
+
+ with self.sql_execution_asserter() as asserter:
+ if synchronize_session is not None:
+ opts = {"synchronize_session": synchronize_session}
+ else:
+ opts = {}
+
+ if synchronize_session == "fetch":
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "The 'fetch' synchronization strategy is not available "
+ "for 'bulk' ORM updates",
+ ):
+ sess.execute(update(User), data, execution_options=opts)
+ return
+ else:
+ sess.execute(update(User), data, execution_options=opts)
+
+ if homogeneous_keys:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=:age_int "
+ "WHERE users.id = :users_id",
+ [
+ {"age_int": 35, "users_id": 1},
+ {"age_int": 27, "users_id": 2},
+ {"age_int": 30, "users_id": 3},
+ ],
+ )
+ )
+ else:
+ asserter.assert_(
+ CompiledSQL(
+ "UPDATE users SET age_int=:age_int "
+ "WHERE users.id = :users_id",
+ [{"age_int": 35, "users_id": 1}],
+ ),
+ CompiledSQL(
+ "UPDATE users SET name=:name WHERE users.id = :users_id",
+ [{"name": "new jack", "users_id": 2}],
+ ),
+ CompiledSQL(
+ "UPDATE users SET name=:name, age_int=:age_int "
+ "WHERE users.id = :users_id",
+ [{"name": "new jill", "age_int": 30, "users_id": 3}],
+ ),
+ )
+
+ if synchronize_session is False:
+ eq_(jill.name, "jill")
+ eq_(jack.name, "jack")
+ eq_(jill.age, 29)
+ eq_(jack.age, 47)
+ else:
+ if not homogeneous_keys:
+ eq_(jill.name, "new jill")
+ eq_(jack.name, "new jack")
+ eq_(jack.age, 47)
+ else:
+ eq_(jack.age, 27)
+ eq_(jill.age, 30)
+
def test_update_changes_with_autoflush(self):
User = self.classes.User
)
@testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
- def test_update_returns_rowcount(self):
+ @testing.combinations("auto", "fetch", "evaluate")
+ def test_update_returns_rowcount(self, synchronize_session):
User = self.classes.User
sess = fixture_session()
rowcount = (
sess.query(User)
.filter(User.age > 29)
- .update({"age": User.age + 0})
+ .update(
+ {"age": User.age + 0}, synchronize_session=synchronize_session
+ )
)
eq_(rowcount, 2)
rowcount = (
sess.query(User)
.filter(User.age > 29)
- .update({"age": User.age - 10})
+ .update(
+ {"age": User.age - 10}, synchronize_session=synchronize_session
+ )
)
eq_(rowcount, 2)
# test future
result = sess.execute(
- update(User).where(User.age > 19).values({"age": User.age - 10})
+ update(User).where(User.age > 19).values({"age": User.age - 10}),
+ execution_options={"synchronize_session": synchronize_session},
)
eq_(result.rowcount, 4)
)
assert john not in sess
- def test_evaluate_before_update(self):
+ @testing.combinations(True, False)
+ def test_evaluate_before_update(self, full_expiration):
User = self.classes.User
sess = fixture_session()
john = sess.query(User).filter_by(name="john").one()
- sess.expire(john, ["age"])
+
+ if full_expiration:
+ sess.expire(john)
+ else:
+ sess.expire(john, ["age"])
# eval must be before the update. otherwise
# we eval john, age has been expired and doesn't
eq_(john.name, "j2")
eq_(john.age, 40)
- def test_evaluate_before_delete(self):
+ @testing.combinations(True, False)
+ def test_evaluate_before_delete(self, full_expiration):
User = self.classes.User
sess = fixture_session()
john = sess.query(User).filter_by(name="john").one()
- sess.expire(john, ["age"])
+ jill = sess.query(User).filter_by(name="jill").one()
+ jane = sess.query(User).filter_by(name="jane").one()
- sess.query(User).filter_by(name="john").filter_by(age=25).delete(
+ if full_expiration:
+ sess.expire(jill)
+ sess.expire(john)
+ else:
+ sess.expire(jill, ["age"])
+ sess.expire(john, ["age"])
+
+ sess.query(User).filter(or_(User.age == 25, User.age == 37)).delete(
synchronize_session="evaluate"
)
- assert john not in sess
+
+ # was fully deleted
+ assert jane not in sess
+
+ # deleted object was expired, but not otherwise affected
+ assert jill in sess
+
+ # deleted object was expired, but not otherwise affected
+ assert john in sess
+
+ # partially expired row fully expired
+ assert inspect(jill).expired
+
+ # non-deleted row still present
+ eq_(jill.age, 29)
+
+ # partially expired row fully expired
+ assert inspect(john).expired
+
+ # is deleted
+ with expect_raises(orm_exc.ObjectDeletedError):
+ john.name
def test_fetch_before_delete(self):
User = self.classes.User
sess.query(User).filter_by(name="john").filter_by(age=25).delete(
synchronize_session="fetch"
)
+
assert john not in sess
def test_update_unordered_dict(self):
]
eq_(["name", "age_int"], cols)
+ @testing.requires.sqlite
+ def test_sharding_extension_returning_mismatch(self, testing_engine):
+ """test one horizontal shard case where the given binds don't match
+ for RETURNING support; we dont support this.
+
+ See test/ext/test_horizontal_shard.py for complete round trip
+ test cases for ORM update/delete
+
+ """
+ e1 = testing_engine("sqlite://")
+ e2 = testing_engine("sqlite://")
+ e1.connect().close()
+ e2.connect().close()
+
+ e1.dialect.update_returning = True
+ e2.dialect.update_returning = False
+
+ engines = [e1, e2]
+
+ # a simulated version of the horizontal sharding extension
+ def execute_and_instances(orm_context):
+ execution_options = dict(orm_context.local_execution_options)
+ partial = []
+ for engine in engines:
+ bind_arguments = dict(orm_context.bind_arguments)
+ bind_arguments["bind"] = engine
+ result_ = orm_context.invoke_statement(
+ bind_arguments=bind_arguments,
+ execution_options=execution_options,
+ )
+
+ partial.append(result_)
+ return partial[0].merge(*partial[1:])
+
+ User = self.classes.User
+ session = Session()
+
+ event.listen(
+ session, "do_orm_execute", execute_and_instances, retval=True
+ )
+
+ stmt = (
+ update(User)
+ .filter(User.id == 15)
+ .values(age=123)
+ .execution_options(synchronize_session="fetch")
+ )
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "For synchronize_session='fetch', can't mix multiple backends "
+ "where some support RETURNING and others don't",
+ ):
+ session.execute(stmt)
+
class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
@classmethod
"Could not evaluate current criteria in Python.",
q.update,
{"samename": "ed"},
+ synchronize_session="evaluate",
)
@testing.requires.multi_table_update
sess.commit()
eq_(d1.cnt, 0)
- sess.query(Data).update({Data.cnt: Data.cnt + 1})
+ sess.query(Data).update({Data.cnt: Data.cnt + 1}, "evaluate")
sess.flush()
eq_(d1.cnt, 1)
)
@testing.requires.update_returning
- def test_load_from_update(self, connection):
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_update(self, connection, use_from_statement):
User = self.classes.User
stmt = (
.returning(User)
)
- stmt = select(User).from_statement(stmt)
+ if use_from_statement:
+ # this is now a legacy-ish case, because as of 2.0 you can just
+ # use returning() directly to get the objects back.
+ #
+ # when from_statement is used, the UPDATE statement is no
+ # longer interpreted by
+ # BulkUDCompileState.orm_pre_session_exec or
+ # BulkUDCompileState.orm_setup_cursor_result. The compilation
+ # level routines still take place though
+ stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
("multiple", testing.requires.multivalues_inserts),
argnames="params",
)
- def test_load_from_insert(self, connection, params):
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_insert(self, connection, params, use_from_statement):
User = self.classes.User
if params == "multiple":
stmt = insert(User).values(values).returning(User)
- stmt = select(User).from_statement(stmt)
+ if use_from_statement:
+ stmt = select(User).from_statement(stmt)
with Session(connection) as sess:
rows = sess.execute(stmt).scalars().all()
)
else:
assert False
+
+ @testing.requires.delete_returning
+ @testing.combinations(True, False, argnames="use_from_statement")
+ def test_load_from_delete(self, connection, use_from_statement):
+ User = self.classes.User
+
+ stmt = (
+ delete(User).where(User.name.in_(["jack", "jill"])).returning(User)
+ )
+
+ if use_from_statement:
+ stmt = select(User).from_statement(stmt)
+
+ with Session(connection) as sess:
+ rows = sess.execute(stmt).scalars().all()
+
+ eq_(
+ rows,
+ [User(name="jack", age=47), User(name="jill", age=29)],
+ )
+
+ # TODO: state of above objects should be "deleted"
and testing.db.dialect.supports_default_metavalue,
[
CompiledSQL(
- "INSERT INTO a (id) VALUES (DEFAULT)", [{}, {}, {}, {}]
+ "INSERT INTO a (id) VALUES (DEFAULT) RETURNING a.id",
+ [{}, {}, {}, {}],
),
],
[
),
(
lambda User: update(User)
+ .execution_options(synchronize_session=False)
.values(name="not ed")
.where(User.name == "ed"),
lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
engine = {"e1": e1, "e2": e2, "e3": e3}[expected_engine_name]
with mock.patch(
- "sqlalchemy.orm.context.ORMCompileState.orm_setup_cursor_result"
+ "sqlalchemy.orm.context." "ORMCompileState.orm_setup_cursor_result"
+ ), mock.patch(
+ "sqlalchemy.orm.context.ORMCompileState.orm_execute_statement"
+ ), mock.patch(
+ "sqlalchemy.orm.bulk_persistence."
+ "BulkORMInsert.orm_execute_statement"
+ ), mock.patch(
+ "sqlalchemy.orm.bulk_persistence."
+ "BulkUDCompileState.orm_setup_cursor_result"
):
sess.execute(statement)
import dataclasses
import operator
+import random
import sqlalchemy as sa
from sqlalchemy import ForeignKey
+from sqlalchemy import insert
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
is g.edges[1]
)
- def test_bulk_update_sql(self):
+ def test_update_crit_sql(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
dialect="default",
)
- def test_bulk_update_evaluate(self):
+ def test_update_crit_evaluate(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
eq_(e1.end, Point(17, 8))
- def test_bulk_update_fetch(self):
+ def test_update_crit_fetch(self):
Edge, Point = (self.classes.Edge, self.classes.Point)
sess = self._fixture()
eq_(e1.end, Point(17, 8))
+ @testing.combinations(
+ "legacy", "statement", "values", "stmt_returning", "values_returning"
+ )
+ def test_bulk_insert(self, type_):
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+ Graph = self.classes.Graph
+
+ sess = self._fixture()
+
+ graph = Graph(id=2)
+ sess.add(graph)
+ sess.flush()
+ graph_id = 2
+
+ data = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "end": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(25)
+ ]
+ returning = False
+ if type_ == "statement":
+ sess.execute(insert(Edge), data)
+ elif type_ == "stmt_returning":
+ result = sess.scalars(insert(Edge).returning(Edge), data)
+ returning = True
+ elif type_ == "values":
+ sess.execute(insert(Edge).values(data))
+ elif type_ == "values_returning":
+ result = sess.scalars(insert(Edge).values(data).returning(Edge))
+ returning = True
+ elif type_ == "legacy":
+ sess.bulk_insert_mappings(Edge, data)
+ else:
+ assert False
+
+ if returning:
+ eq_(result.all(), [Edge(rec["start"], rec["end"]) for rec in data])
+
+ edges = self.tables.edges
+ eq_(
+ sess.execute(
+ select(edges.c["x1", "y1", "x2", "y2"])
+ .where(edges.c.graph_id == graph_id)
+ .order_by(edges.c.id)
+ ).all(),
+ [
+ (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
+ for e in data
+ ],
+ )
+
+ @testing.combinations("legacy", "statement")
+ def test_bulk_insert_heterogeneous(self, type_):
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+ Graph = self.classes.Graph
+
+ sess = self._fixture()
+
+ graph = Graph(id=2)
+ sess.add(graph)
+ sess.flush()
+ graph_id = 2
+
+ d1 = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "end": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(3)
+ ]
+ d2 = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(2)
+ ]
+ d3 = [
+ {
+ "x2": random.randint(1, 50),
+ "y2": random.randint(1, 50),
+ "graph_id": graph_id,
+ }
+ for i in range(2)
+ ]
+ data = d1 + d2 + d3
+ random.shuffle(data)
+
+ assert_data = [
+ {
+ "start": d["start"] if "start" in d else None,
+ "end": d["end"]
+ if "end" in d
+ else Point(d["x2"], d["y2"])
+ if "x2" in d
+ else None,
+ "graph_id": d["graph_id"],
+ }
+ for d in data
+ ]
+
+ if type_ == "statement":
+ sess.execute(insert(Edge), data)
+ elif type_ == "legacy":
+ sess.bulk_insert_mappings(Edge, data)
+ else:
+ assert False
+
+ edges = self.tables.edges
+ eq_(
+ sess.execute(
+ select(edges.c["x1", "y1", "x2", "y2"])
+ .where(edges.c.graph_id == graph_id)
+ .order_by(edges.c.id)
+ ).all(),
+ [
+ (
+ e["start"].x if e["start"] else None,
+ e["start"].y if e["start"] else None,
+ e["end"].x if e["end"] else None,
+ e["end"].y if e["end"] else None,
+ )
+ for e in assert_data
+ ],
+ )
+
+ @testing.combinations("legacy", "statement")
+ def test_bulk_update(self, type_):
+ Edge, Point = (self.classes.Edge, self.classes.Point)
+ Graph = self.classes.Graph
+
+ sess = self._fixture()
+
+ graph = Graph(id=2)
+ sess.add(graph)
+ sess.flush()
+ graph_id = 2
+
+ data = [
+ {
+ "start": Point(random.randint(1, 50), random.randint(1, 50)),
+ "end": Point(random.randint(1, 50), random.randint(1, 50)),
+ "graph_id": graph_id,
+ }
+ for i in range(25)
+ ]
+ sess.execute(insert(Edge), data)
+
+ inserted_data = [
+ dict(row._mapping)
+ for row in sess.execute(
+ select(Edge.id, Edge.start, Edge.end, Edge.graph_id)
+ .where(Edge.graph_id == graph_id)
+ .order_by(Edge.id)
+ )
+ ]
+
+ to_update = []
+ updated_pks = {}
+ for rec in random.choices(inserted_data, k=7):
+ rec_copy = dict(rec)
+ updated_pks[rec_copy["id"]] = rec_copy
+ rec_copy["start"] = Point(
+ random.randint(1, 50), random.randint(1, 50)
+ )
+ rec_copy["end"] = Point(
+ random.randint(1, 50), random.randint(1, 50)
+ )
+ to_update.append(rec_copy)
+
+ expected_dataset = [
+ updated_pks[row["id"]] if row["id"] in updated_pks else row
+ for row in inserted_data
+ ]
+
+ if type_ == "statement":
+ sess.execute(update(Edge), to_update)
+ elif type_ == "legacy":
+ sess.bulk_update_mappings(Edge, to_update)
+ else:
+ assert False
+
+ edges = self.tables.edges
+ eq_(
+ sess.execute(
+ select(edges.c["x1", "y1", "x2", "y2"])
+ .where(edges.c.graph_id == graph_id)
+ .order_by(edges.c.id)
+ ).all(),
+ [
+ (e["start"].x, e["start"].y, e["end"].x, e["end"].y)
+ for e in expected_dataset
+ ],
+ )
+
def test_get_history(self):
Edge = self.classes.Edge
Point = self.classes.Point
[
CompiledSQL(
"INSERT INTO ball (person_id, data) "
- "VALUES (:person_id, :data)",
+ "VALUES (:person_id, :data) RETURNING ball.id",
[
{"person_id": None, "data": "some data"},
{"person_id": None, "data": "some data"},
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
[{"foo": 5, "test_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test SET foo=:foo WHERE test.id = :test_id",
[{"foo": 6, "test_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.bar AS test_bar FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 1}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.bar AS test_bar FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 2}],
+ enable_returning=False,
),
)
else:
canary = self._flag_fixture(sess)
- sess.execute(delete(User).filter_by(id=18))
- sess.execute(update(User).filter_by(id=18).values(name="eighteen"))
+ sess.execute(
+ delete(User)
+ .filter_by(id=18)
+ .execution_options(synchronize_session="evaluate")
+ )
+ sess.execute(
+ update(User)
+ .filter_by(id=18)
+ .values(name="eighteen")
+ .execution_options(synchronize_session="evaluate")
+ )
eq_(
canary.mock_calls,
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO users (name) VALUES (:name)",
+ "INSERT INTO users (name) VALUES (:name) "
+ "RETURNING users.id",
[{"name": "u1"}, {"name": "u2"}],
),
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address)",
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
[
{"user_id": 1, "email_address": "a1"},
{"user_id": 2, "email_address": "a2"},
[
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address)",
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
lambda ctx: [
{"email_address": "a1", "user_id": u1.id},
{"email_address": "a2", "user_id": u1.id},
[
CompiledSQL(
"INSERT INTO addresses (user_id, email_address) "
- "VALUES (:user_id, :email_address)",
+ "VALUES (:user_id, :email_address) "
+ "RETURNING addresses.id",
lambda ctx: [
{"email_address": "a1", "user_id": u1.id},
{"email_address": "a2", "user_id": u1.id},
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n2"},
{"parent_id": n1.id, "data": "n3"},
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n2"},
{"parent_id": n1.id, "data": "n3"},
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n1.id, "data": "n11"},
{"parent_id": n1.id, "data": "n12"},
[
CompiledSQL(
"INSERT INTO nodes (parent_id, data) VALUES "
- "(:parent_id, :data)",
+ "(:parent_id, :data) RETURNING nodes.id",
lambda ctx: [
{"parent_id": n12.id, "data": "n121"},
{"parent_id": n12.id, "data": "n122"},
testing.db.dialect.insert_executemany_returning,
[
CompiledSQL(
- "INSERT INTO t (data) VALUES (:data)",
+ "INSERT INTO t (data) VALUES (:data) RETURNING t.id",
[{"data": "t1"}, {"data": "t2"}],
),
],
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
[{"id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
[{"id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.foo AS test_foo FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 1}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test.foo AS test_foo FROM test "
"WHERE test.id = :pk_1",
[{"pk_1": 2}],
+ enable_returning=False,
),
)
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 5, "test2_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 6, "bar": 10, "test2_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 7, "test2_id": 3}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 8, "bar": 12, "test2_id": 4}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test2.bar AS test2_bar FROM test2 "
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
"WHERE test4.id = :test4_id",
[{"foo": 5, "test4_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=:bar "
"WHERE test4.id = :test4_id",
[{"foo": 6, "bar": 10, "test4_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=5 + 3 "
"WHERE test4.id = :test4_id",
[{"foo": 7, "test4_id": 3}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test4 SET foo=:foo, bar=:bar "
"WHERE test4.id = :test4_id",
[{"foo": 8, "bar": 12, "test4_id": 4}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test4.bar AS test4_bar FROM test4 "
"WHERE test4.id = :pk_1",
[{"pk_1": 1}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test4.bar AS test4_bar FROM test4 "
"WHERE test4.id = :pk_1",
[{"pk_1": 3}],
+ enable_returning=False,
),
],
),
"UPDATE test2 SET foo=:foo, bar=1 + 1 "
"WHERE test2.id = :test2_id",
[{"foo": 5, "test2_id": 1}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=:bar "
"WHERE test2.id = :test2_id",
[{"foo": 6, "bar": 10, "test2_id": 2}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
[{"foo": 7, "test2_id": 3}],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE test2 SET foo=:foo, bar=5 + 7 "
"WHERE test2.id = :test2_id",
[{"foo": 8, "test2_id": 4}],
+ enable_returning=False,
),
CompiledSQL(
"SELECT test2.bar AS test2_bar FROM test2 "
sess.add(f1)
statements = [
- # note that the assertsql tests the rule against
- # "default" - on a "returning" backend, the statement
- # includes "RETURNING"
CompiledSQL(
"INSERT INTO version_table (version_id, value) "
- "VALUES (1, :value)",
+ "VALUES (1, :value) "
+ "RETURNING version_table.id, version_table.version_id",
lambda ctx: [{"value": "f1"}],
)
]
"value": "f2",
}
],
+ enable_returning=False,
),
CompiledSQL(
"SELECT version_table.version_id "
"value": "f1a",
}
],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE version_table SET version_id=2, value=:value "
"value": "f2a",
}
],
+ enable_returning=False,
),
CompiledSQL(
"UPDATE version_table SET version_id=2, value=:value "
"value": "f3a",
}
],
+ enable_returning=False,
),
CompiledSQL(
"SELECT version_table.version_id "
Table(
"test",
metadata,
- Column("x", Integer, primary_key=True),
+ Column(
+ "x", Integer, primary_key=True, test_needs_autoincrement=False
+ ),
Column("y", String(50)),
)
+ @testing.requires.insert_returning
+ def test_splice_horizontally(self, connection):
+ users = self.tables.users
+ addresses = self.tables.addresses
+
+ r1 = connection.execute(
+ users.insert().returning(users.c.user_name, users.c.user_id),
+ [
+ dict(user_id=1, user_name="john"),
+ dict(user_id=2, user_name="jack"),
+ ],
+ )
+
+ r2 = connection.execute(
+ addresses.insert().returning(
+ addresses.c.address_id,
+ addresses.c.address,
+ addresses.c.user_id,
+ ),
+ [
+ dict(address_id=1, user_id=1, address="foo@bar.com"),
+ dict(address_id=2, user_id=2, address="bar@bat.com"),
+ ],
+ )
+
+ rows = r1.splice_horizontally(r2).all()
+ eq_(
+ rows,
+ [
+ ("john", 1, 1, "foo@bar.com", 1),
+ ("jack", 2, 2, "bar@bat.com", 2),
+ ],
+ )
+
+ eq_(rows[0]._mapping[users.c.user_id], 1)
+ eq_(rows[0]._mapping[addresses.c.user_id], 1)
+ eq_(rows[1].address, "bar@bat.com")
+
+ with expect_raises_message(
+ exc.InvalidRequestError, "Ambiguous column name 'user_id'"
+ ):
+ rows[0].user_id
+
def test_keys_no_rows(self, connection):
for i in range(2):
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing import provision
from sqlalchemy.testing.schema import Column
stmt = stmt.returning(t.c.x)
stmt = stmt.return_defaults()
+
assert_raises_message(
sa_exc.CompileError,
r"Can't compile statement that includes returning\(\) "
table = self.tables.returning_tbl
exprs = testing.resolve_lambda(testcase, table=table)
+
result = connection.execute(
table.insert().returning(*exprs),
{"persons": 5, "full": False, "strval": "str1"},
Column("upddef", Integer, onupdate=IncDefault()),
)
+ Table(
+ "table_no_addtl_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ class MyType(TypeDecorator):
+ impl = String(50)
+
+ def process_result_value(self, value, dialect):
+ return f"PROCESSED! {value}"
+
+ Table(
+ "table_datatype_has_result_proc",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", MyType()),
+ )
+
def test_chained_insert_pk(self, connection):
t1 = self.tables.t1
result = connection.execute(
)
eq_(result.inserted_primary_key, (1,))
+ def test_insert_w_defaults_supplemental_cols(self, connection):
+ t1 = self.tables.t1
+ result = connection.execute(
+ t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
+ {"data": "d1"},
+ )
+ eq_(result.all(), [(1, 0, None)])
+
+ def test_insert_w_no_defaults_supplemental_cols(self, connection):
+ t1 = self.tables.table_no_addtl_defaults
+ result = connection.execute(
+ t1.insert().return_defaults(supplemental_cols=[t1.c.id]),
+ {"data": "d1"},
+ )
+ eq_(result.all(), [(1,)])
+
+ def test_insert_w_defaults_supplemental_processor_cols(self, connection):
+ """test that the cursor._rewind() used by supplemental RETURNING
+ clears out result-row processors as we will have already processed
+ the rows.
+
+ """
+
+ t1 = self.tables.table_datatype_has_result_proc
+ result = connection.execute(
+ t1.insert().return_defaults(
+ supplemental_cols=[t1.c.id, t1.c.data]
+ ),
+ {"data": "d1"},
+ )
+ eq_(result.all(), [(1, "PROCESSED! d1")])
+
class UpdatedReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("update_returning",)
t1 = self.tables.t1
connection.execute(t1.insert().values(upddef=1))
+
result = connection.execute(
t1.update().values(upddef=2).return_defaults(t1.c.data)
)
[None],
)
+ def test_update_values_col_is_excluded(self, connection):
+ """columns that are in values() are not returned"""
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+
+ result = connection.execute(
+ t1.update().values(data="x", upddef=2).return_defaults(t1.c.data)
+ )
+ is_(result.returned_defaults, None)
+
+ result = connection.execute(
+ t1.update()
+ .values(data="x", upddef=2)
+ .return_defaults(t1.c.data, t1.c.id)
+ )
+ eq_(result.returned_defaults, (1,))
+
+ def test_update_supplemental_cols(self, connection):
+ """with supplemental_cols, we can get back arbitrary cols."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.update()
+ .values(data="x", insdef=3)
+ .return_defaults(supplemental_cols=[t1.c.data, t1.c.insdef])
+ )
+
+ row = result.returned_defaults
+
+ # row has all the cols in it
+ eq_(row, ("x", 3, 1))
+ eq_(row._mapping[t1.c.upddef], 1)
+ eq_(row._mapping[t1.c.insdef], 3)
+
+ # result is rewound
+ # but has both return_defaults + supplemental_cols
+ eq_(result.all(), [("x", 3, 1)])
+
+ def test_update_expl_return_defaults_plus_supplemental_cols(
+ self, connection
+ ):
+ """with supplemental_cols, we can get back arbitrary cols."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.update()
+ .values(data="x", insdef=3)
+ .return_defaults(
+ t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
+ )
+ )
+
+ row = result.returned_defaults
+
+ # row has all the cols in it
+ eq_(row, (1, "x", 3))
+ eq_(row._mapping[t1.c.id], 1)
+ eq_(row._mapping[t1.c.insdef], 3)
+ assert t1.c.upddef not in row._mapping
+
+ # result is rewound
+ # but has both return_defaults + supplemental_cols
+ eq_(result.all(), [(1, "x", 3)])
+
def test_update_sql_expr(self, connection):
from sqlalchemy import literal
eq_(dict(result.returned_defaults._mapping), {"upddef": 1})
+class DeleteReturnDefaultsTest(fixtures.TablesTest):
+ __requires__ = ("delete_returning",)
+ run_define_tables = "each"
+ __backend__ = True
+
+ define_tables = InsertReturnDefaultsTest.define_tables
+
+ def test_delete(self, connection):
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(t1.delete().return_defaults(t1.c.upddef))
+ eq_(
+ [result.returned_defaults._mapping[k] for k in (t1.c.upddef,)], [1]
+ )
+
+ def test_delete_empty_return_defaults(self, connection):
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=5))
+ result = connection.execute(t1.delete().return_defaults())
+
+ # there's no "delete" default, so we get None. we have to
+ # ask for them in all cases
+ eq_(result.returned_defaults, None)
+
+ def test_delete_non_default(self, connection):
+ """test that a column not marked at all as a
+ default works with this feature."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(t1.delete().return_defaults(t1.c.data))
+ eq_(
+ [result.returned_defaults._mapping[k] for k in (t1.c.data,)],
+ [None],
+ )
+
+ def test_delete_non_default_plus_default(self, connection):
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.delete().return_defaults(t1.c.data, t1.c.upddef)
+ )
+ eq_(
+ dict(result.returned_defaults._mapping),
+ {"data": None, "upddef": 1},
+ )
+
+ def test_delete_supplemental_cols(self, connection):
+ """with supplemental_cols, we can get back arbitrary cols."""
+
+ t1 = self.tables.t1
+ connection.execute(t1.insert().values(upddef=1))
+ result = connection.execute(
+ t1.delete().return_defaults(
+ t1.c.id, supplemental_cols=[t1.c.data, t1.c.insdef]
+ )
+ )
+
+ row = result.returned_defaults
+
+ # row has all the cols in it
+ eq_(row, (1, None, 0))
+ eq_(row._mapping[t1.c.insdef], 0)
+
+ # result is rewound
+ # but has both return_defaults + supplemental_cols
+ eq_(result.all(), [(1, None, 0)])
+
+
class InsertManyReturnDefaultsTest(fixtures.TablesTest):
__requires__ = ("insert_executemany_returning",)
run_define_tables = "each"
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import visitors
+from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.selectable import LABEL_STYLE_NONE
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
eq_(whereclause.left._annotations, {"foo": "bar"})
eq_(whereclause.right._annotations, {"foo": "bar"})
+ @testing.combinations(True, False, None)
+ def test_setup_inherit_cache(self, inherit_cache_value):
+ if inherit_cache_value is None:
+
+ class MyInsertThing(Insert):
+ pass
+
+ else:
+
+ class MyInsertThing(Insert):
+ inherit_cache = inherit_cache_value
+
+ t = table("t", column("x"))
+ anno = MyInsertThing(t)._annotate({"foo": "bar"})
+
+ if inherit_cache_value is not None:
+ is_(type(anno).__dict__["inherit_cache"], inherit_cache_value)
+ else:
+ assert "inherit_cache" not in type(anno).__dict__
+
def test_proxy_set_iteration_includes_annotated(self):
from sqlalchemy.schema import Column