def _raw_row_iterator(self):
return self._fetchiter_impl()
+ def merge(self, *others):
+ merged_result = super(CursorResult, self).merge(*others)
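+ # for results that don't return rows (INSERT/UPDATE/DELETE), the
+ # merged result reports the aggregate rowcount across all results,
+ # e.g. when a bulk UPDATE has been fanned out across shards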
+ setup_rowcounts = not self._metadata.returns_rows
+ if setup_rowcounts:
+ merged_result.rowcount = sum(
+ result.rowcount for result in (self,) + others
+ )
+ return merged_result
+
def close(self):
"""Close this :class:`_engine.CursorResult`.
"""
return self._allrows()
- def _only_one_row(self, raise_for_second_row, raise_for_none):
+ def _only_one_row(self, raise_for_second_row, raise_for_none, scalar):
onerow = self._fetchone_impl
row = onerow(hard_close=True)
# if we checked for second row then that would have
# closed us :)
self._soft_close(hard=True)
- post_creational_filter = self._post_creational_filter
- if post_creational_filter:
- row = post_creational_filter(row)
- return row
+ if not scalar:
+ post_creational_filter = self._post_creational_filter
+ if post_creational_filter:
+ row = post_creational_filter(row)
+
+ if scalar and row:
+ return row[0]
+ else:
+ return row
def first(self):
"""Fetch the first row or None if no row is present.
Closes the result set and discards remaining rows.
+ .. note:: This method returns one **row**, e.g. a tuple, by default.
+ To return exactly one scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.first`.
+
.. comment: A warning is emitted if additional rows remain.
:return: a :class:`.Row` object if no filters are applied, or None
if no rows remain.
When filters are applied, such as :meth:`_engine.Result.mappings`
- or :meth:`._engine.Result.scalar`, different kinds of objects
+ or :meth:`_engine.Result.scalars`, different kinds of objects
may be returned.
+ .. seealso::
+
+ :meth:`_result.Result.scalar`
+
+ :meth:`_result.Result.one`
+
"""
- return self._only_one_row(False, False)
+ return self._only_one_row(False, False, False)
def one_or_none(self):
"""Return at most one result or raise an exception.
:meth:`_result.Result.one`
"""
- return self._only_one_row(True, False)
+ return self._only_one_row(True, False, False)
+
+ def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one`.
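+
+ E.g.::
+
+     # illustrative sketch; ``select``, ``User`` and ``session`` are
+     # assumed to be set up in the usual ORM way
+     user_id = session.execute(
+         select(User.id).where(User.name == "some user")
+     ).scalar_one()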
+
+ .. seealso::
+
+ :meth:`.Result.one`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(True, True, True)
+
+ def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`.Result.scalars` and then
+ :meth:`.Result.one_or_none`.
+
+ .. seealso::
+
+ :meth:`.Result.one_or_none`
+
+ :meth:`.Result.scalars`
+
+ """
+ return self._only_one_row(True, False, True)
def one(self):
- """Return exactly one result or raise an exception.
+ """Return exactly one row or raise an exception.
Raises :class:`.NoResultFound` if the result returns no
rows, or :class:`.MultipleResultsFound` if multiple rows
would be returned.
+ .. note:: This method returns one **row**, e.g. a tuple, by default.
+ To return exactly one scalar value, that is, the first
+ column of the first row, use the :meth:`.Result.scalar_one` method,
+ or combine :meth:`.Result.scalars` and :meth:`.Result.one`.
+
.. versionadded:: 1.4
:return: The first :class:`.Row`.
:meth:`_result.Result.one_or_none`
+ :meth:`_result.Result.scalar_one`
+
"""
- return self._only_one_row(True, True)
+ return self._only_one_row(True, True, False)
def scalar(self):
"""Fetch the first column of the first row, and close the result set.
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
After calling this method, the object is fully closed,
e.g. the :meth:`_engine.CursorResult.close`
method will have been called.
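+ E.g.::
+
+     # illustrative sketch; ``conn`` is assumed to be a Connection and
+     # ``users`` a Table; ``select`` / ``func`` are the usual constructs
+     count = conn.execute(
+         select(func.count()).select_from(users)
+     ).scalar()
+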
- :return: a Python scalar value , or None if no rows remain
+ :return: a Python scalar value, or None if no rows remain.
"""
- row = self.first()
- if row is not None:
- return row[0]
- else:
- return None
+ return self._only_one_row(False, False, True)
class FrozenResult(object):
"""
return self.execution_options(_sa_shard_id=shard_id)
- def _execute_crud(self, stmt, mapper):
- def exec_for_shard(shard_id):
- conn = self.session.connection(
- mapper=mapper,
- shard_id=shard_id,
- clause=stmt,
- close_with_result=True,
- )
- result = conn._execute_20(
- stmt, self.load_options._params, self._execution_options
- )
- return result
-
- if self._shard_id is not None:
- return exec_for_shard(self._shard_id)
- else:
- rowcount = 0
- results = []
- # TODO: this will have to be the new object
- for shard_id in self.execute_chooser(self):
- result = exec_for_shard(shard_id)
- rowcount += result.rowcount
- results.append(result)
-
- return ShardedResult(results, rowcount)
-
-
-class ShardedResult(object):
- """A value object that represents multiple :class:`_engine.CursorResult`
- objects.
-
- This is used by the :meth:`.ShardedQuery._execute_crud` hook to return
- an object that takes the place of the single :class:`_engine.CursorResult`.
-
- Attribute include ``result_proxies``, which is a sequence of the
- actual :class:`_engine.CursorResult` objects,
- as well as ``aggregate_rowcount``
- or ``rowcount``, which is the sum of all the individual rowcount values.
-
- .. versionadded:: 1.3
- """
-
- __slots__ = ("result_proxies", "aggregate_rowcount")
-
- def __init__(self, result_proxies, aggregate_rowcount):
- self.result_proxies = result_proxies
- self.aggregate_rowcount = aggregate_rowcount
-
- @property
- def rowcount(self):
- return self.aggregate_rowcount
-
class ShardedSession(Session):
def __init__(
def execute_and_instances(orm_context):
- if orm_context.bind_arguments.get("_horizontal_shard", False):
- return None
-
params = orm_context.parameters
- load_options = orm_context.load_options
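+ # choose the options appropriate to the statement type: load options
+ # for SELECT, update/delete options for bulk UPDATE/DELETE; either
+ # way, "active_options" carries the identity token used for routing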
+ if orm_context.is_select:
+ load_options = active_options = orm_context.load_options
+ update_options = None
+ if params is None:
+ params = active_options._params
+
+ else:
+ load_options = None
+ update_options = active_options = orm_context.update_delete_options
+
session = orm_context.session
# orm_query = orm_context.orm_query
- if params is None:
- params = load_options._params
-
- def iter_for_shard(shard_id, load_options):
+ def iter_for_shard(shard_id, load_options, update_options):
execution_options = dict(orm_context.local_execution_options)
bind_arguments = dict(orm_context.bind_arguments)
- bind_arguments["_horizontal_shard"] = True
bind_arguments["shard_id"] = shard_id
- load_options += {"_refresh_identity_token": shard_id}
- execution_options["_sa_orm_load_options"] = load_options
+ if orm_context.is_select:
+ load_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_load_options"] = load_options
+ else:
+ update_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_update_options"] = update_options
- return session.execute(
- orm_context.statement,
- orm_context.parameters,
- execution_options,
- bind_arguments,
+ return orm_context.invoke_statement(
+ bind_arguments=bind_arguments, execution_options=execution_options
)
- if load_options._refresh_identity_token is not None:
- shard_id = load_options._refresh_identity_token
+ if active_options._refresh_identity_token is not None:
+ shard_id = active_options._refresh_identity_token
elif "_sa_shard_id" in orm_context.merged_execution_options:
shard_id = orm_context.merged_execution_options["_sa_shard_id"]
elif "shard_id" in orm_context.bind_arguments:
shard_id = None
if shard_id is not None:
- return iter_for_shard(shard_id, load_options)
+ return iter_for_shard(shard_id, load_options, update_options)
else:
partial = []
for shard_id in session.execute_chooser(orm_context):
- result_ = iter_for_shard(shard_id, load_options)
+ result_ = iter_for_shard(shard_id, load_options, update_options)
partial.append(result_)
return partial[0].merge(*partial[1:])
from .. import util
from ..orm import attributes
from ..orm import interfaces
-
+from ..sql import elements
HYBRID_METHOD = util.symbol("HYBRID_METHOD")
"""Symbol indicating an :class:`InspectionAttr` that's
return self.hybrid.info
def _bulk_update_tuples(self, value):
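+ # a value arriving from the new ORM-enabled update() may already be
+ # coerced into a BindParameter; unwrap it so the raw value can be
+ # re-processed against the hybrid's expression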
+ if isinstance(value, elements.BindParameter):
+ value = value.value
+
if isinstance(self.expression, attributes.QueryableAttribute):
return self.expression._bulk_update_tuples(value)
elif self.hybrid.update_expr is not None:
@classmethod
def orm_pre_session_exec(
- cls, session, statement, execution_options, bind_arguments
+ cls, session, statement, params, execution_options, bind_arguments
):
load_options = execution_options.get(
"_sa_orm_load_options", QueryContext.default_load_options
if load_options._autoflush:
session._autoflush()
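+ # returning execution_options allows subclasses such as
+ # BulkUDCompileState to hand back an augmented set of options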
+ return execution_options
+
@classmethod
def orm_setup_cursor_result(
cls, session, statement, execution_options, bind_arguments, result
return CompositeProperty.CompositeBundle(self.prop, clauses)
def _bulk_update_tuples(self, value):
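+ # as with hybrid attributes, unwrap a value that has already been
+ # coerced into a BindParameter back to the raw composite value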
+ if isinstance(value, sql.elements.BindParameter):
+ value = value.value
+
if value is None:
values = [None for key in self.prop._attribute_keys]
elif isinstance(value, self.prop.composite_class):
lambda update_context: (
update_context.session,
update_context.query,
- update_context.context,
+ None,
update_context.result,
),
)
was called upon.
* ``values`` The "values" dictionary that was passed to
:meth:`_query.Query.update`.
- * ``context`` The :class:`.QueryContext` object, corresponding
- to the invocation of an ORM query.
* ``result`` the :class:`_engine.CursorResult`
returned as a result of the
bulk UPDATE operation.
+ .. versionchanged:: 1.4 the update_context no longer has a
+ ``QueryContext`` object associated with it.
+
.. seealso::
:meth:`.QueryEvents.before_compile_update`
lambda delete_context: (
delete_context.session,
delete_context.query,
- delete_context.context,
+ None,
delete_context.result,
),
)
* ``query`` - the :class:`_query.Query`
object that this delete operation
was called upon.
- * ``context`` The :class:`.QueryContext` object, corresponding
- to the invocation of an ORM query.
* ``result`` the :class:`_engine.CursorResult`
returned as a result of the
bulk DELETE operation.
+ .. versionchanged:: 1.4 the delete_context no longer has a
+ ``QueryContext`` object associated with it.
+
.. seealso::
:meth:`.QueryEvents.before_compile_delete`
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
- return self.selectable._annotate(
- {
- "entity_namespace": self,
- "parententity": self,
- "parentmapper": self,
- "compile_state_plugin": "orm",
- }
- )._set_propagate_attrs(
+
+ annotations = {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ "compile_state_plugin": "orm",
+ }
+ if self.persist_selectable is not self.local_table:
+ # joined table inheritance, with polymorphic selectable,
+ # etc.
+ annotations["dml_table"] = self.local_table._annotate(
+ {
+ "entity_namespace": self,
+ "parententity": self,
+ "parentmapper": self,
+ "compile_state_plugin": "orm",
+ }
+ )._set_propagate_attrs(
+ {"compile_state_plugin": "orm", "plugin_subject": self}
+ )
+
+ return self.selectable._annotate(annotations)._set_propagate_attrs(
{"compile_state_plugin": "orm", "plugin_subject": self}
)
from .. import future
from .. import sql
from .. import util
+from ..future import select as future_select
from ..sql import coercions
from ..sql import expression
from ..sql import operators
from ..sql import roles
-from ..sql.base import _from_objects
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.dml import DeleteDMLState
+from ..sql.dml import UpdateDMLState
from ..sql.elements import BooleanClauseList
)
-class BulkUD(object):
- """Handle bulk update and deletes via a :class:`_query.Query`."""
+_EMPTY_DICT = util.immutabledict()
- def __init__(self, query):
- self.query = query.enable_eagerloads(False)
- self._validate_query_state()
- def _validate_query_state(self):
- for attr, methname, notset, op in (
- ("_limit_clause", "limit()", None, operator.is_),
- ("_offset_clause", "offset()", None, operator.is_),
- ("_order_by_clauses", "order_by()", (), operator.eq),
- ("_group_by_clauses", "group_by()", (), operator.eq),
- ("_distinct", "distinct()", False, operator.is_),
- (
- "_from_obj",
- "join(), outerjoin(), select_from(), or from_self()",
- (),
- operator.eq,
- ),
- (
- "_legacy_setup_joins",
- "join(), outerjoin(), select_from(), or from_self()",
- (),
- operator.eq,
- ),
- ):
- if not op(getattr(self.query, attr), notset):
- raise sa_exc.InvalidRequestError(
- "Can't call Query.update() or Query.delete() "
- "when %s has been called" % (methname,)
- )
-
- @property
- def session(self):
- return self.query.session
+class BulkUDCompileState(CompileState):
+ class default_update_options(Options):
+ _synchronize_session = "evaluate"
+ _autoflush = True
+ _subject_mapper = None
+ _resolved_values = _EMPTY_DICT
+ _resolved_keys_as_propnames = _EMPTY_DICT
+ _value_evaluators = _EMPTY_DICT
+ _matched_objects = None
+ _matched_rows = None
+ _refresh_identity_token = None
@classmethod
- def _factory(cls, lookup, synchronize_session, *arg):
- try:
- klass = lookup[synchronize_session]
- except KeyError as err:
- util.raise_(
- sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
- ),
- replace_context=err,
+ def orm_pre_session_exec(
+ cls, session, statement, params, execution_options, bind_arguments
+ ):
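+ # "synchronize_session" may be supplied either as a per-execution
+ # option or as an option embedded on the statement itself; the
+ # per-execution option takes precedence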
+ sync = execution_options.get("synchronize_session", None)
+ if sync is None:
+ sync = statement._execution_options.get(
+ "synchronize_session", None
)
- else:
- return klass(*arg)
-
- def exec_(self):
- self._do_before_compile()
- self._do_pre()
- self._do_pre_synchronize()
- self._do_exec()
- self._do_post_synchronize()
- self._do_post()
-
- def _execute_stmt(self, stmt):
- self.result = self.query._execute_crud(stmt, self.mapper)
- self.rowcount = self.result.rowcount
-
- def _do_before_compile(self):
- raise NotImplementedError()
- @util.preload_module("sqlalchemy.orm.context")
- def _do_pre(self):
- query_context = util.preloaded.orm_context
- query = self.query
-
- self.compile_state = (
- self.context
- ) = compile_state = query._compile_state()
-
- self.mapper = compile_state._entity_zero()
-
- if isinstance(
- compile_state._entities[0], query_context._RawColumnEntity,
- ):
- # check for special case of query(table)
- tables = set()
- for ent in compile_state._entities:
- if not isinstance(ent, query_context._RawColumnEntity,):
- tables.clear()
- break
- else:
- tables.update(_from_objects(ent.column))
+ update_options = execution_options.get(
+ "_sa_orm_update_options",
+ BulkUDCompileState.default_update_options,
+ )
- if len(tables) != 1:
- raise sa_exc.InvalidRequestError(
- "This operation requires only one Table or "
- "entity be specified as the target."
+ if sync is not None:
+ if sync not in ("evaluate", "fetch", False):
+ raise sa_exc.ArgumentError(
+ "Valid strategies for session synchronization "
+ "are 'evaluate', 'fetch', False"
)
- else:
- self.primary_table = tables.pop()
+ update_options += {"_synchronize_session": sync}
+ bind_arguments["clause"] = statement
+ try:
+ plugin_subject = statement._propagate_attrs["plugin_subject"]
+ except KeyError:
+ assert False, "statement had 'orm' plugin but no plugin_subject"
else:
- self.primary_table = compile_state._only_entity_zero(
- "This operation requires only one Table or "
- "entity be specified as the target."
- ).mapper.local_table
+ bind_arguments["mapper"] = plugin_subject.mapper
- session = query.session
+ update_options += {"_subject_mapper": plugin_subject.mapper}
- if query.load_options._autoflush:
+ if update_options._autoflush:
session._autoflush()
- def _do_pre_synchronize(self):
- pass
+ if update_options._synchronize_session == "evaluate":
+ update_options = cls._do_pre_synchronize_evaluate(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
+ elif update_options._synchronize_session == "fetch":
+ update_options = cls._do_pre_synchronize_fetch(
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ )
- def _do_post_synchronize(self):
- pass
+ return util.immutabledict(execution_options).union(
+ dict(_sa_orm_update_options=update_options)
+ )
+ @classmethod
+ def orm_setup_cursor_result(
+ cls, session, statement, execution_options, bind_arguments, result
+ ):
+ update_options = execution_options["_sa_orm_update_options"]
+ if update_options._synchronize_session == "evaluate":
+ cls._do_post_synchronize_evaluate(session, update_options)
+ elif update_options._synchronize_session == "fetch":
+ cls._do_post_synchronize_fetch(session, update_options)
-class BulkEvaluate(BulkUD):
- """BulkUD which does the 'evaluate' method of session state resolution."""
+ return result
- def _additional_evaluators(self, evaluator_compiler):
- pass
+ @classmethod
+ def _do_pre_synchronize_evaluate(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
+ target_cls = mapper.class_
- def _do_pre_synchronize(self):
- query = self.query
- target_cls = self.compile_state._mapper_zero().class_
+ value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
try:
evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- if query._where_criteria:
+ if statement._where_criteria:
eval_condition = evaluator_compiler.process(
- *query._where_criteria
+ *statement._where_criteria
)
else:
def eval_condition(obj):
return True
- self._additional_evaluators(evaluator_compiler)
+ # TODO: something more robust for this conditional
+ if statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ value_evaluators = {}
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ for key, value in resolved_keys_as_propnames:
+ value_evaluators[key] = evaluator_compiler.process(
+ coercions.expect(roles.ExpressionElementRole, value)
+ )
except evaluator.UnevaluatableError as err:
util.raise_(
sa_exc.InvalidRequestError(
'Could not evaluate current criteria in Python: "%s". '
"Specify 'fetch' or False for the "
- "synchronize_session parameter." % err
+ "synchronize_session execution option." % err
),
from_=err,
)
# TODO: detect when the where clause is a trivial primary key match
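+ # objects in the identity map are matched by class, by the evaluated
+ # criteria, and by identity token, so that under horizontal sharding
+ # only the shard being targeted gets synchronized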
- self.matched_objects = [
+ matched_objects = [
obj
- for (
- cls,
- pk,
- identity_token,
- ), obj in query.session.identity_map.items()
- if issubclass(cls, target_cls) and eval_condition(obj)
+ for (cls, pk, identity_token,), obj in session.identity_map.items()
+ if issubclass(cls, target_cls)
+ and eval_condition(obj)
+ and identity_token == update_options._refresh_identity_token
]
-
-
-class BulkFetch(BulkUD):
- """BulkUD which does the 'fetch' method of session state resolution."""
-
- def _do_pre_synchronize(self):
- query = self.query
- session = query.session
- select_stmt = self.compile_state.statement.with_only_columns(
- self.primary_table.primary_key
- )
- self.matched_rows = session.execute(
- select_stmt, mapper=self.mapper, params=query.load_options._params
- ).fetchall()
-
-
-class BulkUpdate(BulkUD):
- """BulkUD which handles UPDATEs."""
-
- def __init__(self, query, values, update_kwargs):
- super(BulkUpdate, self).__init__(query)
- self.values = values
- self.update_kwargs = update_kwargs
+ return update_options + {
+ "_matched_objects": matched_objects,
+ "_value_evaluators": value_evaluators,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
@classmethod
- def factory(cls, query, synchronize_session, values, update_kwargs):
- return BulkUD._factory(
- {
- "evaluate": BulkUpdateEvaluate,
- "fetch": BulkUpdateFetch,
- False: BulkUpdate,
- },
- synchronize_session,
- query,
- values,
- update_kwargs,
- )
-
- def _do_before_compile(self):
- if self.query.dispatch.before_compile_update:
- for fn in self.query.dispatch.before_compile_update:
- new_query = fn(self.query, self)
- if new_query is not None:
- self.query = new_query
+ def _get_resolved_values(cls, mapper, statement):
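+ # multi-VALUES (executemany-style) parameters can't be resolved to
+ # mapped attributes here; they pass through unchanged and are
+ # rejected later when the re-constructed UPDATE is compiled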
+ if statement._multi_values:
+ return []
+ elif statement._ordered_values:
+ iterator = statement._ordered_values
+ elif statement._values:
+ iterator = statement._values.items()
+ else:
+ return []
- @property
- def _resolved_values(self):
values = []
- for k, v in (
- self.values.items()
- if hasattr(self.values, "items")
- else self.values
- ):
- if self.mapper:
- if isinstance(k, util.string_types):
- desc = sql.util._entity_namespace_key(self.mapper, k)
- values.extend(desc._bulk_update_tuples(v))
- elif isinstance(k, attributes.QueryableAttribute):
- values.extend(k._bulk_update_tuples(v))
+ if iterator:
+ for k, v in iterator:
+ if mapper:
+ if isinstance(k, util.string_types):
+ desc = sql.util._entity_namespace_key(mapper, k)
+ values.extend(desc._bulk_update_tuples(v))
+ elif isinstance(k, attributes.QueryableAttribute):
+ values.extend(k._bulk_update_tuples(v))
+ else:
+ values.append((k, v))
else:
values.append((k, v))
- else:
- values.append((k, v))
return values
- @property
- def _resolved_values_keys_as_propnames(self):
+ @classmethod
+ def _resolved_keys_as_propnames(cls, mapper, resolved_values):
values = []
- for k, v in self._resolved_values:
+ for k, v in resolved_values:
if isinstance(k, attributes.QueryableAttribute):
values.append((k.key, v))
continue
elif hasattr(k, "__clause_element__"):
k = k.__clause_element__()
- if self.mapper and isinstance(k, expression.ColumnElement):
+ if mapper and isinstance(k, expression.ColumnElement):
try:
- attr = self.mapper._columntoproperty[k]
+ attr = mapper._columntoproperty[k]
except orm_exc.UnmappedColumnError:
pass
else:
)
return values
- def _do_exec(self):
- values = self._resolved_values
+ @classmethod
+ def _do_pre_synchronize_fetch(
+ cls,
+ session,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ update_options,
+ ):
+ mapper = update_options._subject_mapper
- if not self.update_kwargs.get("preserve_parameter_order", False):
- values = dict(values)
+ if mapper:
+ primary_table = mapper.local_table
+ else:
+ primary_table = statement._raw_columns[0]
- update_stmt = sql.update(
- self.primary_table, **self.update_kwargs
- ).values(values)
+ # note this creates a Select() *without* the ORM plugin.
+ # we don't want that here.
+ select_stmt = future_select(*primary_table.primary_key)
+ select_stmt._where_criteria = statement._where_criteria
- update_stmt._where_criteria = self.compile_state._where_criteria
+ matched_rows = session.execute(
+ select_stmt, params, execution_options, bind_arguments
+ ).fetchall()
- self._execute_stmt(update_stmt)
+ if statement.__visit_name__ == "update":
+ resolved_values = cls._get_resolved_values(mapper, statement)
+ resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+ mapper, resolved_values
+ )
+ else:
+ resolved_keys_as_propnames = _EMPTY_DICT
- def _do_post(self):
- session = self.query.session
- session.dispatch.after_bulk_update(self)
+ return update_options + {
+ "_matched_rows": matched_rows,
+ "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+ }
-class BulkDelete(BulkUD):
- """BulkUD which handles DELETEs."""
+@CompileState.plugin_for("orm", "update")
+class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
- def __init__(self, query):
- super(BulkDelete, self).__init__(query)
+ self = cls.__new__(cls)
- @classmethod
- def factory(cls, query, synchronize_session):
- return BulkUD._factory(
- {
- "evaluate": BulkDeleteEvaluate,
- "fetch": BulkDeleteFetch,
- False: BulkDelete,
- },
- synchronize_session,
- query,
+ self.mapper = mapper = statement.table._annotations.get(
+ "parentmapper", None
)
- def _do_before_compile(self):
- if self.query.dispatch.before_compile_delete:
- for fn in self.query.dispatch.before_compile_delete:
- new_query = fn(self.query, self)
- if new_query is not None:
- self.query = new_query
+ self._resolved_values = cls._get_resolved_values(mapper, statement)
- def _do_exec(self):
- delete_stmt = sql.delete(self.primary_table,)
- delete_stmt._where_criteria = self.compile_state._where_criteria
+ if not statement._preserve_parameter_order and statement._values:
+ self._resolved_values = dict(self._resolved_values)
- self._execute_stmt(delete_stmt)
+ new_stmt = sql.Update.__new__(sql.Update)
+ new_stmt.__dict__.update(statement.__dict__)
+ new_stmt.table = mapper.local_table
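+ # the UPDATE statement is shallow-copied and re-targeted at the
+ # mapper's local table; the ORM-annotated table may otherwise be a
+ # non-updatable selectable, such as a join used by inheritance
+ # mappings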
- def _do_post(self):
- session = self.query.session
- session.dispatch.after_bulk_delete(self)
+ # note if the statement has _multi_values, these
+ # are passed through to the new statement, which will then raise
+ # InvalidRequestError because UPDATE doesn't support multi_values
+ # right now.
+ if statement._ordered_values:
+ new_stmt._ordered_values = self._resolved_values
+ elif statement._values:
+ new_stmt._values = self._resolved_values
+ UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
-class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
- """BulkUD which handles UPDATEs using the "evaluate"
- method of session resolution."""
+ return self
- def _additional_evaluators(self, evaluator_compiler):
- self.value_evaluators = {}
- values = self._resolved_values_keys_as_propnames
- for key, value in values:
- self.value_evaluators[key] = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, update_options):
- def _do_post_synchronize(self):
- session = self.query.session
states = set()
- evaluated_keys = list(self.value_evaluators.keys())
- for obj in self.matched_objects:
+ evaluated_keys = list(update_options._value_evaluators.keys())
+ for obj in update_options._matched_objects:
+
state, dict_ = (
attributes.instance_state(obj),
attributes.instance_dict(obj),
)
+ assert (
+ state.identity_token == update_options._refresh_identity_token
+ )
+
# only evaluate unmodified attributes
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
- dict_[key] = self.value_evaluators[key](obj)
+ dict_[key] = update_options._value_evaluators[key](obj)
state.manager.dispatch.refresh(state, None, to_evaluate)
states.add(state)
session._register_altered(states)
-
-class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
- """BulkUD which handles DELETEs using the "evaluate"
- method of session resolution."""
-
- def _do_post_synchronize(self):
- self.query.session._remove_newly_deleted(
- [attributes.instance_state(obj) for obj in self.matched_objects]
- )
-
-
-class BulkUpdateFetch(BulkFetch, BulkUpdate):
- """BulkUD which handles UPDATEs using the "fetch"
- method of session resolution."""
-
- def _do_post_synchronize(self):
- session = self.query.session
- target_mapper = self.compile_state._mapper_zero()
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, update_options):
+ target_mapper = update_options._subject_mapper
states = set(
[
attributes.instance_state(session.identity_map[identity_key])
for identity_key in [
target_mapper.identity_key_from_primary_key(
- list(primary_key)
+ list(primary_key),
+ identity_token=update_options._refresh_identity_token,
)
- for primary_key in self.matched_rows
+ for primary_key in update_options._matched_rows
]
if identity_key in session.identity_map
]
)
- values = self._resolved_values_keys_as_propnames
+ values = update_options._resolved_keys_as_propnames
attrib = set(k for k, v in values)
for state in states:
to_expire = attrib.intersection(state.dict)
session._register_altered(states)
-class BulkDeleteFetch(BulkFetch, BulkDelete):
- """BulkUD which handles DELETEs using the "fetch"
- method of session resolution."""
+@CompileState.plugin_for("orm", "delete")
+class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
+ @classmethod
+ def create_for_statement(cls, statement, compiler, **kw):
+ self = cls.__new__(cls)
+
+ self.mapper = statement.table._annotations.get("parentmapper", None)
+
+ DeleteDMLState.__init__(self, statement, compiler, **kw)
+
+ return self
+
+ @classmethod
+ def _do_post_synchronize_evaluate(cls, session, update_options):
+
+ session._remove_newly_deleted(
+ [
+ attributes.instance_state(obj)
+ for obj in update_options._matched_objects
+ ]
+ )
+
+ @classmethod
+ def _do_post_synchronize_fetch(cls, session, update_options):
+ target_mapper = update_options._subject_mapper
- def _do_post_synchronize(self):
- session = self.query.session
- target_mapper = self.compile_state._mapper_zero()
- for primary_key in self.matched_rows:
+ for primary_key in update_options._matched_rows:
# TODO: inline this and call remove_newly_deleted
# once
identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key)
+ list(primary_key),
+ identity_token=update_options._refresh_identity_token,
)
if identity_key in session.identity_map:
session._remove_newly_deleted(
"""
import itertools
+import operator
from . import attributes
from . import exc as orm_exc
from . import interfaces
from . import loading
-from . import persistence
from .base import _assertions
from .context import _column_descriptions
from .context import _legacy_determine_last_joined_entity
return result
- def _execute_crud(self, stmt, mapper):
- conn = self.session.connection(
- mapper=mapper, clause=stmt, close_with_result=True
- )
-
- return conn._execute_20(
- stmt, self.load_options._params, self._execution_options
- )
-
def __str__(self):
statement = self._statement_20()
"""
- delete_op = persistence.BulkDelete.factory(self, synchronize_session)
- delete_op.exec_()
- return delete_op.rowcount
+ bulk_del = BulkDelete(self,)
+ if self.dispatch.before_compile_delete:
+ for fn in self.dispatch.before_compile_delete:
+ new_query = fn(bulk_del.query, bulk_del)
+ if new_query is not None:
+ bulk_del.query = new_query
+
+ self = bulk_del.query
+
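+ # Query.delete() now proceeds through session.execute() with an
+ # ORM-enabled Core delete() construct; synchronize_session is
+ # passed along as an execution option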
+ delete_ = sql.delete(*self._raw_columns)
+ delete_._where_criteria = self._where_criteria
+ result = self.session.execute(
+ delete_,
+ self.load_options._params,
+ execution_options={"synchronize_session": synchronize_session},
+ )
+ bulk_del.result = result
+ self.session.dispatch.after_bulk_delete(bulk_del)
+ result.close()
+
+ return result.rowcount
def update(self, values, synchronize_session="evaluate", update_args=None):
r"""Perform a bulk update query.
"""
update_args = update_args or {}
- update_op = persistence.BulkUpdate.factory(
- self, synchronize_session, values, update_args
+
+ bulk_ud = BulkUpdate(self, values, update_args)
+
+ if self.dispatch.before_compile_update:
+ for fn in self.dispatch.before_compile_update:
+ new_query = fn(bulk_ud.query, bulk_ud)
+ if new_query is not None:
+ bulk_ud.query = new_query
+ self = bulk_ud.query
+
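+ # as with delete(), the UPDATE is emitted through session.execute()
+ # with synchronize_session passed as an execution option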
+ upd = sql.update(*self._raw_columns, **update_args).values(values)
+ upd._where_criteria = self._where_criteria
+ result = self.session.execute(
+ upd,
+ self.load_options._params,
+ execution_options={"synchronize_session": synchronize_session},
)
- update_op.exec_()
- return update_op.rowcount
+ bulk_ud.result = result
+ self.session.dispatch.after_bulk_update(bulk_ud)
+ result.close()
+ return result.rowcount
def _compile_state(self, for_statement=False, **kw):
"""Create an out-of-compiler ORMCompileState object.
def process_compile_state(self, compile_state):
pass
+
+
+class BulkUD(object):
+ """State used for the orm.Query version of update() / delete().
+
+ This object is now specific to Query only.
+
+ """
+
+ def __init__(self, query):
+ self.query = query.enable_eagerloads(False)
+ self._validate_query_state()
+ self.mapper = self.query._entity_from_pre_ent_zero()
+
+ def _validate_query_state(self):
+ for attr, methname, notset, op in (
+ ("_limit_clause", "limit()", None, operator.is_),
+ ("_offset_clause", "offset()", None, operator.is_),
+ ("_order_by_clauses", "order_by()", (), operator.eq),
+ ("_group_by_clauses", "group_by()", (), operator.eq),
+ ("_distinct", "distinct()", False, operator.is_),
+ (
+ "_from_obj",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ (
+ "_legacy_setup_joins",
+ "join(), outerjoin(), select_from(), or from_self()",
+ (),
+ operator.eq,
+ ),
+ ):
+ if not op(getattr(self.query, attr), notset):
+ raise sa_exc.InvalidRequestError(
+ "Can't call Query.update() or Query.delete() "
+ "when %s has been called" % (methname,)
+ )
+
+ @property
+ def session(self):
+ return self.query.session
+
+
+class BulkUpdate(BulkUD):
+ """BulkUD which handles UPDATEs."""
+
+ def __init__(self, query, values, update_kwargs):
+ super(BulkUpdate, self).__init__(query)
+ self.values = values
+ self.update_kwargs = update_kwargs
+
+
+class BulkDelete(BulkUD):
+ """BulkUD which handles DELETEs."""
from .. import util
from ..inspection import inspect
from ..sql import coercions
+from ..sql import dml
from ..sql import roles
+from ..sql import selectable
from ..sql import visitors
from ..sql.base import CompileState
"_execution_options",
"_merged_execution_options",
"bind_arguments",
+ "_compile_state_cls",
)
def __init__(
- self, session, statement, parameters, execution_options, bind_arguments
+ self,
+ session,
+ statement,
+ parameters,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
):
self.session = session
self.statement = statement
self.parameters = parameters
self._execution_options = execution_options
self.bind_arguments = bind_arguments
+ self._compile_state_cls = compile_state_cls
def invoke_statement(
self,
statement, _params, _execution_options, _bind_arguments
)
+ @property
+ def is_orm_statement(self):
+ """return True if the operation is an ORM statement.
+
+ This indicates that the select(), update(), or delete() being
+ invoked contains ORM entities as subjects. For a statement
+ that does not have ORM entities and instead refers only to
+ :class:`.Table` metadata, it is invoked as a Core SQL statement
+ and no ORM-level automation takes place.
+
+ """
+ return self._compile_state_cls is not None
+
+ @property
+ def is_select(self):
+ """return True if this is a SELECT operation."""
+ return isinstance(self.statement, selectable.Select)
+
+ @property
+ def is_update(self):
+ """return True if this is an UPDATE operation."""
+ return isinstance(self.statement, dml.Update)
+
+ @property
+ def is_delete(self):
+ """return True if this is a DELETE operation."""
+ return isinstance(self.statement, dml.Delete)
+
+ @property
+ def _is_crud(self):
+ return isinstance(self.statement, (dml.Update, dml.Delete))
+
@property
def execution_options(self):
"""Placeholder for execution options.
def load_options(self):
"""Return the load_options that will be used for this execution."""
+ if not self.is_select:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against a SELECT statement "
+ "so there are no load options."
+ )
return self._execution_options.get(
"_sa_orm_load_options", context.QueryContext.default_load_options
)
+ @property
+ def update_delete_options(self):
+ """Return the update_delete_options that will be used for this
+ execution."""
+
+ if not self._is_crud:
+ raise sa_exc.InvalidRequestError(
+ "This ORM execution is not against an UPDATE or DELETE "
+ "statement so there are no update options."
+ )
+ return self._execution_options.get(
+ "_sa_orm_update_options",
+ persistence.BulkUDCompileState.default_update_options,
+ )
+
@property
def user_defined_options(self):
"""The sequence of :class:`.UserDefinedOptions` that have been
compile_state_cls = CompileState._get_plugin_class_for_plugin(
statement, "orm"
)
+ else:
+ compile_state_cls = None
- compile_state_cls.orm_pre_session_exec(
- self, statement, execution_options, bind_arguments
+ if compile_state_cls is not None:
+ execution_options = compile_state_cls.orm_pre_session_exec(
+ self, statement, params, execution_options, bind_arguments
)
-
- if self.dispatch.do_orm_execute:
- skip_events = bind_arguments.pop("_sa_skip_events", False)
-
- if not skip_events:
- orm_exec_state = ORMExecuteState(
- self,
- statement,
- params,
- execution_options,
- bind_arguments,
- )
- for fn in self.dispatch.do_orm_execute:
- result = fn(orm_exec_state)
- if result:
- return result
-
else:
- compile_state_cls = None
bind_arguments.setdefault("clause", statement)
if statement._is_future:
execution_options = util.immutabledict().merge_with(
execution_options, {"future_result": True}
)
+ if self.dispatch.do_orm_execute:
+ # run this event whether or not we are in ORM mode
+ skip_events = bind_arguments.get("_sa_skip_events", False)
+ if not skip_events:
+ orm_exec_state = ORMExecuteState(
+ self,
+ statement,
+ params,
+ execution_options,
+ bind_arguments,
+ compile_state_cls,
+ )
+ for fn in self.dispatch.do_orm_execute:
+ result = fn(orm_exec_state)
+ if result:
+ return result
+
bind = self.get_bind(**bind_arguments)
conn = self._connection_for_bind(bind, close_with_result=True)
self.__binds[insp] = bind
elif insp.is_mapper:
self.__binds[insp.class_] = bind
- for selectable in insp._all_tables:
- self.__binds[selectable] = bind
+ for _selectable in insp._all_tables:
+ self.__binds[_selectable] = bind
else:
raise sa_exc.ArgumentError(
"Not an acceptable bind target: %s" % key
"""
self._add_bind(table, bind)
- def get_bind(self, mapper=None, clause=None, bind=None):
+ def get_bind(
+ self, mapper=None, clause=None, bind=None, _sa_skip_events=None
+ ):
"""Return a "bind" to which this :class:`.Session` is bound.
The "bind" is usually an instance of :class:`_engine.Engine`,
plugin_name = statement._propagate_attrs.get(
"compile_state_plugin", "default"
)
- else:
- plugin_name = "default"
+ klass = cls.plugins.get(
+ (plugin_name, statement.__visit_name__), None
+ )
+ if klass is None:
+ klass = cls.plugins[("default", statement.__visit_name__)]
- klass = cls.plugins[(plugin_name, statement.__visit_name__)]
+ else:
+ klass = cls.plugins[("default", statement.__visit_name__)]
if klass is cls:
return cls(statement, compiler, **kw)
return element.alias(name=name, flat=flat)
+class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+ __slots__ = ()
+
+ def _post_coercion(self, element, **kw):
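+ # if the coerced element carries a "dml_table" annotation (set up by
+ # the ORM for mappers whose selectable is not the local table), that
+ # annotated table becomes the actual DML target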
+ if "dml_table" in element._annotations:
+ return element._annotations["dml_table"]
+ else:
+ return element
+
+
class DMLSelectImpl(_NoTextCoercion, RoleImpl):
__slots__ = ()
toplevel = not self.stack
if toplevel:
self.isupdate = True
+ if not self.compile_state:
+ self.compile_state = compile_state
extra_froms = compile_state._extra_froms
is_multitable = bool(extra_froms)
toplevel = not self.stack
if toplevel:
self.isdelete = True
+ if not self.compile_state:
+ self.compile_state = compile_state
extra_froms = compile_state._extra_froms
from .base import DialectKWArgs
from .base import Executable
from .base import HasCompileState
+from .elements import BooleanClauseList
from .elements import ClauseElement
from .elements import Null
from .selectable import HasCTE
def __init__(self, statement, compiler, **kw):
self.statement = statement
-
self.isupdate = True
self._preserve_parameter_order = statement._preserve_parameter_order
if statement._ordered_values is not None:
_returning = ()
def __init__(self, table, values, prefixes):
- self.table = coercions.expect(roles.FromClauseRole, table)
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
if values is not None:
self.values.non_generative(self, values)
if prefixes:
coercions.expect(roles.WhereHavingRole, whereclause),
)
+ def filter(self, *criteria):
+ """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+
+ return self.where(*criteria)
+
+ @property
+ def whereclause(self):
+ """Return the completed WHERE clause for this :class:`.DMLWhereBase`
+ statement.
+
+ This assembles the current collection of WHERE criteria
+ into a single :class:`_expression.BooleanClauseList` construct.
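+
+ E.g.::
+
+     # illustrative sketch; ``table`` is assumed to be a Table with
+     # columns ``x`` and ``y``
+     stmt = table.update().where(table.c.x > 5).where(table.c.y == 10)
+     criteria = stmt.whereclause  # BooleanClauseList of both criteria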
+
+
+ .. versionadded:: 1.4
+
+ """
+
+ return BooleanClauseList._construct_for_whereclause(
+ self._where_criteria
+ )
+
class Update(DMLWhereBase, ValuesBase):
"""Represent an Update construct.
"""
self._bind = bind
- self.table = coercions.expect(roles.FromClauseRole, table)
+ self.table = coercions.expect(
+ roles.DMLTableRole, table, apply_propagate_attrs=self
+ )
self._returning = returning
if prefixes:
)
+# TODO: are we using this?
class DMLRole(StatementRole):
pass
+class DMLTableRole(FromClauseRole):
+ _role_name = "subject table for an INSERT, UPDATE or DELETE"
+
+
class DMLColumnRole(SQLRole):
_role_name = "SET/VALUES column expression or string key"
self._reset_column_collection()
-class Join(FromClause):
+class Join(roles.DMLTableRole, FromClause):
"""represent a ``JOIN`` construct between two
:class:`_expression.FromClause`
elements.
return self.element.bind
-class Alias(AliasedReturnsRows):
+class Alias(roles.DMLTableRole, AliasedReturnsRows):
"""Represents an table or selectable alias (AS).
Represents an alias, as typically applied to any table or
self.element = state["element"]
-class TableClause(Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, FromClause):
"""Represents a minimal "table" construct.
This is a lightweight table object that has only a name, a
from ..inspection import inspect
from ..util import collections_abc
from ..util import HasMemoized
+from ..util import py37
SKIP_TRAVERSE = util.symbol("skip_traverse")
COMPARE_FAILED = False
)
def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+ if py37:
+ # in py37 we can assume dictionaries retain their insertion
+ # ordering, so two dicts built with the same key order will
+ # compare deterministically
+ return (
+ attrname,
+ tuple(
+ (
+ k._gen_cache_key(anon_map, bindparams)
+ if hasattr(k, "__clause_element__")
+ else k,
+ obj[k]._gen_cache_key(anon_map, bindparams),
+ )
+ for k in obj
+ ),
+ )
+ else:
+ expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+ if expr_values:
+ # expr values can't be sorted deterministically right now,
+ # so no cache
+ anon_map[NO_CACHE] = True
+ return ()
- expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
- if expr_values:
- # expr values can't be sorted deterministically right now,
- # so no cache
- anon_map[NO_CACHE] = True
- return ()
-
- str_values = expr_values.symmetric_difference(obj)
+ str_values = expr_values.symmetric_difference(obj)
- return (
- attrname,
- tuple(
- (k, obj[k]._gen_cache_key(anon_map, bindparams))
- for k in sorted(str_values)
- ),
- )
+ return (
+ attrname,
+ tuple(
+ (k, obj[k]._gen_cache_key(anon_map, bindparams))
+ for k in sorted(str_values)
+ ),
+ )
def visit_dml_multi_values(
self, attrname, obj, parent, anon_map, bindparams
for lv, rv in zip(left, right):
if not self._compare_dml_values_or_ce(lv, rv, **kw):
return COMPARE_FAILED
+ elif isinstance(right, collections_abc.Sequence):
+ return COMPARE_FAILED
+ elif py37:
+ # dictionaries guaranteed to support insert ordering in
+ # py37 so that we can compare the keys in order; without
+ # this, we can't compare SQL expression keys because we don't
+ # know which key is which
+ for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
+ if not self._compare_dml_values_or_ce(lk, rk, **kw):
+ return COMPARE_FAILED
+ if not self._compare_dml_values_or_ce(lv, rv, **kw):
+ return COMPARE_FAILED
else:
for lk in left:
lv = left[lk]
LABEL_STYLE_TABLENAME_PLUS_COL
)
clause = compile_state.statement
- elif isinstance(clause, orm.persistence.BulkUD):
- with mock.patch.object(clause, "_execute_stmt") as stmt_mock:
- clause.exec_()
- clause = stmt_mock.mock_calls[0][1][0]
if compile_kwargs:
kw["compile_kwargs"] = compile_kwargs
from .compat import print_ # noqa
from .compat import py2k # noqa
from .compat import py36 # noqa
+from .compat import py37 # noqa
from .compat import py3k # noqa
from .compat import quote_plus # noqa
from .compat import raise_ # noqa
import sys
+py37 = sys.version_info >= (3, 7)
py36 = sys.version_info >= (3, 6)
py3k = sys.version_info >= (3, 0)
py2k = sys.version_info < (3, 0)
ORMCompileState.orm_pre_session_exec(
sess,
compile_state.select_statement,
+ {},
exec_opts,
bind_arguments,
)
def test_scalar_one(self):
result = self._fixture(num_rows=1)
+ row = result.scalar_one()
+ eq_(row, 1)
+
+ def test_scalars_plus_one(self):
+ result = self._fixture(num_rows=1)
+
row = result.scalars().one()
eq_(row, 1)
- def test_scalar_one_none(self):
+ def test_scalars_plus_one_none(self):
result = self._fixture(num_rows=0)
result = result.scalars()
result.one,
)
+ def test_one_or_none(self):
+ result = self._fixture(num_rows=1)
+
+ eq_(result.one_or_none(), (1, 1, 1))
+
+ def test_scalar_one_or_none(self):
+ result = self._fixture(num_rows=1)
+
+ eq_(result.scalar_one_or_none(), 1)
+
+ def test_scalar_one_or_none_none(self):
+ result = self._fixture(num_rows=0)
+
+ eq_(result.scalar_one_or_none(), None)
+
def test_one_or_none_none(self):
result = self._fixture(num_rows=0)
from sqlalchemy import Column
from sqlalchemy import DateTime
+from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
+from sqlalchemy import update
from sqlalchemy import util
from sqlalchemy.ext.horizontal_shard import ShardedSession
from sqlalchemy.future import select as future_select
t = get_tokyo(sess2)
eq_(t.city, tokyo.city)
- def test_bulk_update(self):
+ def test_bulk_update_synchronize_evaluate(self):
sess = self._fixture_data()
eq_(
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
sess.query(Report).filter(Report.temperature >= 80).update(
- {"temperature": Report.temperature + 6}
+ {"temperature": Report.temperature + 6},
+ synchronize_session="evaluate",
)
eq_(
# test synchronize session as well
eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
- def test_bulk_delete(self):
+ def test_bulk_update_synchronize_fetch(self):
+ sess = self._fixture_data()
+
+ eq_(
+ set(row.temperature for row in sess.query(Report.temperature)),
+ {80.0, 75.0, 85.0},
+ )
+
+ temps = sess.query(Report).all()
+ eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+ sess.query(Report).filter(Report.temperature >= 80).update(
+ {"temperature": Report.temperature + 6},
+ synchronize_session="fetch",
+ )
+
+ eq_(
+ set(row.temperature for row in sess.query(Report.temperature)),
+ {86.0, 75.0, 91.0},
+ )
+
+ # test synchronize session as well
+ eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
+
+ def test_bulk_delete_synchronize_evaluate(self):
+ sess = self._fixture_data()
+
+ temps = sess.query(Report).all()
+ eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+ sess.query(Report).filter(Report.temperature >= 80).delete(
+ synchronize_session="evaluate"
+ )
+
+ eq_(
+ set(row.temperature for row in sess.query(Report.temperature)),
+ {75.0},
+ )
+
+ # test synchronize session as well
+ for t in temps:
+ assert inspect(t).deleted is (t.temperature >= 80)
+
+ def test_bulk_delete_synchronize_fetch(self):
sess = self._fixture_data()
temps = sess.query(Report).all()
eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
- sess.query(Report).filter(Report.temperature >= 80).delete()
+ sess.query(Report).filter(Report.temperature >= 80).delete(
+ synchronize_session="fetch"
+ )
eq_(
set(row.temperature for row in sess.query(Report.temperature)),
for t in temps:
assert inspect(t).deleted is (t.temperature >= 80)
+ def test_bulk_update_future_synchronize_evaluate(self):
+ sess = self._fixture_data()
+
+ eq_(
+ set(
+ row.temperature
+ for row in sess.execute(future_select(Report.temperature))
+ ),
+ {80.0, 75.0, 85.0},
+ )
+
+ temps = sess.execute(future_select(Report)).scalars().all()
+ eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+ sess.execute(
+ update(Report)
+ .filter(Report.temperature >= 80)
+ .values({"temperature": Report.temperature + 6},)
+ .execution_options(synchronize_session="evaluate")
+ )
+
+ eq_(
+ set(
+ row.temperature
+ for row in sess.execute(future_select(Report.temperature))
+ ),
+ {86.0, 75.0, 91.0},
+ )
+
+ # test synchronize session as well
+ eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
+
+ def test_bulk_update_future_synchronize_fetch(self):
+ sess = self._fixture_data()
+
+ eq_(
+ set(
+ row.temperature
+ for row in sess.execute(future_select(Report.temperature))
+ ),
+ {80.0, 75.0, 85.0},
+ )
+
+ temps = sess.execute(future_select(Report)).scalars().all()
+ eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+ sess.execute(
+ update(Report)
+ .filter(Report.temperature >= 80)
+ .values({"temperature": Report.temperature + 6},)
+ .execution_options(synchronize_session="fetch")
+ )
+
+ eq_(
+ set(
+ row.temperature
+ for row in sess.execute(future_select(Report.temperature))
+ ),
+ {86.0, 75.0, 91.0},
+ )
+
+ # test synchronize session as well
+ eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
+
+ def test_bulk_delete_future_synchronize_evaluate(self):
+ sess = self._fixture_data()
+
+ temps = sess.execute(future_select(Report)).scalars().all()
+ eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+ sess.execute(
+ delete(Report)
+ .filter(Report.temperature >= 80)
+ .execution_options(synchronize_session="evaluate")
+ )
+
+ eq_(
+ set(
+ row.temperature
+ for row in sess.execute(future_select(Report.temperature))
+ ),
+ {75.0},
+ )
+
+ # test synchronize session as well
+ for t in temps:
+ assert inspect(t).deleted is (t.temperature >= 80)
+
+ def test_bulk_delete_future_synchronize_fetch(self):
+ sess = self._fixture_data()
+
+ temps = sess.execute(future_select(Report)).scalars().all()
+ eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+ sess.execute(
+ delete(Report)
+ .filter(Report.temperature >= 80)
+ .execution_options(synchronize_session="fetch")
+ )
+
+ eq_(
+ set(
+ row.temperature
+ for row in sess.execute(future_select(Report.temperature))
+ ),
+ {75.0},
+ )
+
+ # test synchronize session as well
+ for t in temps:
+ assert inspect(t).deleted is (t.temperature >= 80)
+
class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
def _init_dbs(self):
from sqlalchemy.ext import hybrid
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import aliased
-from sqlalchemy.orm import persistence
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
+from sqlalchemy.sql import update
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
def test_update_plain(self):
Person = self.classes.Person
- s = Session()
- q = s.query(Person)
-
- bulk_ud = persistence.BulkUpdate.factory(
- q, False, {Person.fname: "Dr."}, {}
- )
+ statement = update(Person).values({Person.fname: "Dr."})
self.assert_compile(
- bulk_ud,
+ statement,
"UPDATE person SET first_name=:first_name",
params={"first_name": "Dr."},
)
def test_update_expr(self):
Person = self.classes.Person
- s = Session()
- q = s.query(Person)
-
- bulk_ud = persistence.BulkUpdate.factory(
- q, False, {Person.name: "Dr. No"}, {}
- )
+ statement = update(Person).values({Person.name: "Dr. No"})
self.assert_compile(
- bulk_ud,
+ statement,
"UPDATE person SET first_name=:first_name, last_name=:last_name",
params={"first_name": "Dr.", "last_name": "No"},
)
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.future import select as future_select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import composite
from sqlalchemy.orm import CompositeProperty
from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import mapper
-from sqlalchemy.orm import persistence
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import assert_raises_message
sess = self._fixture()
- e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one()
+ e1 = sess.execute(
+ future_select(Edge).filter(Edge.start == Point(14, 5))
+ ).scalar_one()
eq_(e1.end, Point(2, 7))
- q = sess.query(Edge).filter(Edge.start == Point(14, 5))
- bulk_ud = persistence.BulkUpdate.factory(
- q, False, {Edge.end: Point(16, 10)}, {}
+ stmt = (
+ update(Edge)
+ .filter(Edge.start == Point(14, 5))
+ .values({Edge.end: Point(16, 10)})
)
self.assert_compile(
- bulk_ud,
+ stmt,
"UPDATE edges SET x2=:x2, y2=:y2 WHERE edges.x1 = :x1_1 "
"AND edges.y1 = :y1_1",
params={"x2": 16, "x1_1": 14, "y2": 10, "y1_1": 5},
sess = self._fixture()
- e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one()
+ e1 = sess.execute(
+ future_select(Edge).filter(Edge.start == Point(14, 5))
+ ).scalar_one()
eq_(e1.end, Point(2, 7))
- q = sess.query(Edge).filter(Edge.start == Point(14, 5))
- q.update({Edge.end: Point(16, 10)})
+ stmt = (
+ update(Edge)
+ .filter(Edge.start == Point(14, 5))
+ .values({Edge.end: Point(16, 10)})
+ )
+ sess.execute(stmt)
eq_(e1.end, Point(16, 10))
checkparams={"id_1": 5, "name": "ed"},
)
+ self.assert_compile(
+ update(User).values({User.name: "ed"}).where(User.id == 5),
+ "UPDATE users SET name=:name WHERE users.id = :id_1",
+ checkparams={"id_1": 5, "name": "ed"},
+ )
+
def test_delete_from_entity(self):
from sqlalchemy.sql import delete
eq_(upd.session, sess)
eq_(
canary.after_bulk_update_legacy.mock_calls,
- [call(sess, upd.query, upd.context, upd.result)],
+ [call(sess, upd.query, None, upd.result)],
)
def test_on_bulk_delete_hook(self):
eq_(upd.session, sess)
eq_(
canary.after_bulk_delete_legacy.mock_calls,
- [call(sess, upd.query, upd.context, upd.result)],
+ [call(sess, upd.query, None, upd.result)],
)
s.refresh(a1)
# joined eager load didn't continue
eq_(len(a1.bs), 1)
+ s.close()
@_combinations
def test_flag_resolves_existing(self, target, event_name, fn):
s.expire(a1)
event.listen(target, event_name, fn, restore_load_context=True)
s.query(A).all()
+ s.close()
@_combinations
def test_flag_resolves(self, target, event_name, fn):
s.refresh(a1)
# joined eager load continued
eq_(len(a1.bs), 3)
+ s.close()
class DeclarativeEventListenTest(
upd = canary.after_bulk_update.mock_calls[0][1][0]
eq_(upd.session, sess)
+ eq_(upd.result.rowcount, 0)
def test_on_bulk_delete_hook(self):
User, users = self.classes.User, self.tables.users
upd = canary.after_bulk_delete.mock_calls[0][1][0]
eq_(upd.session, sess)
+ eq_(upd.result.rowcount, 0)
def test_connection_emits_after_begin(self):
sess, canary = self._listener_fixture(bind=testing.db)
from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
+from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import text
+from sqlalchemy import update
+from sqlalchemy.future import select as future_select
from sqlalchemy.orm import backref
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import mapper
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-from sqlalchemy.util import collections_abc
class UpdateDeleteTest(fixtures.MappedTest):
list(zip([15, 27, 19, 27])),
)
+ def test_update_future(self):
+ User, users = self.classes.User, self.tables.users
+
+ sess = Session()
+
+ john, jack, jill, jane = (
+ sess.execute(future_select(User).order_by(User.id)).scalars().all()
+ )
+
+ sess.execute(
+ update(User)
+ .where(User.age > 29)
+ .values({"age": User.age - 10})
+ .execution_options(synchronize_session="evaluate"),
+ )
+
+ eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+ eq_(
+ sess.execute(future_select(User.age).order_by(User.id)).all(),
+ list(zip([25, 37, 29, 27])),
+ )
+
+ sess.execute(
+ update(User)
+ .where(User.age > 29)
+ .values({User.age: User.age - 10})
+ .execution_options(synchronize_session="evaluate")
+ )
+ eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27])
+ eq_(
+ sess.query(User.age).order_by(User.id).all(),
+ list(zip([25, 27, 29, 27])),
+ )
+
+ sess.query(User).filter(User.age > 27).update(
+ {users.c.age_int: User.age - 10}, synchronize_session="evaluate"
+ )
+ eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27])
+ eq_(
+ sess.query(User.age).order_by(User.id).all(),
+ list(zip([25, 27, 19, 27])),
+ )
+
+ sess.query(User).filter(User.age == 25).update(
+ {User.age: User.age - 10}, synchronize_session="fetch"
+ )
+ eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27])
+ eq_(
+ sess.query(User.age).order_by(User.id).all(),
+ list(zip([15, 27, 19, 27])),
+ )
+
def test_update_against_table_col(self):
User, users = self.classes.User, self.tables.users
# Do an update using unordered dict and check that the parameters used
# are ordered in table order
+
+ m1 = testing.mock.Mock()
+
+ @event.listens_for(session, "after_bulk_update")
+ def do_orm_execute(bulk_ud):
+ m1(bulk_ud.result.context.compiled.compile_state.statement)
+
q = session.query(User)
- with mock.patch.object(q, "_execute_crud") as exec_:
- q.filter(User.id == 15).update({"name": "foob", "id": 123})
- # Confirm that parameters are a dict instead of tuple or list
- params = exec_.mock_calls[0][1][0]._values
- assert isinstance(params, collections_abc.Mapping)
+ q.filter(User.id == 15).update({"name": "foob", "age": 123})
+ assert m1.mock_calls[0][1][0]._values
- def test_update_preserve_parameter_order(self):
+ def test_update_preserve_parameter_order_query(self):
User = self.classes.User
session = Session()
# Do update using a tuple and check that order is preserved
- q = session.query(User)
- with mock.patch.object(q, "_execute_crud") as exec_:
- q.filter(User.id == 15).update(
- (("id", 123), ("name", "foob")),
- update_args={"preserve_parameter_order": True},
- )
+
+ m1 = testing.mock.Mock()
+
+ @event.listens_for(session, "after_bulk_update")
+ def do_orm_execute(bulk_ud):
+
cols = [
- c.key for c, v in exec_.mock_calls[0][1][0]._ordered_values
+ c.key
+ for c, v in (
+ (
+ bulk_ud.result.context
+ ).compiled.compile_state.statement._ordered_values
+ )
]
- eq_(["id", "name"], cols)
+ m1(cols)
- # Now invert the order and use a list instead, and check that order is
- # also preserved
q = session.query(User)
- with mock.patch.object(q, "_execute_crud") as exec_:
- q.filter(User.id == 15).update(
- [("name", "foob"), ("id", 123)],
- update_args={"preserve_parameter_order": True},
+ q.filter(User.id == 15).update(
+ (("age", 123), ("name", "foob")),
+ update_args={"preserve_parameter_order": True},
+ )
+
+ eq_(m1.mock_calls[0][1][0], ["age_int", "name"])
+
+ m1.mock_calls = []
+
+ q = session.query(User)
+ q.filter(User.id == 15).update(
+ [("name", "foob"), ("age", 123)],
+ update_args={"preserve_parameter_order": True},
+ )
+ eq_(m1.mock_calls[0][1][0], ["name", "age_int"])
+
+ def test_update_multi_values_error_future(self):
+ User = self.classes.User
+ session = Session()
+
+ # Do update using a tuple and check that order is preserved
+
+ stmt = (
+ update(User)
+ .filter(User.id == 15)
+ .values([("id", 123), ("name", "foob")])
+ )
+
+ assert_raises_message(
+ exc.InvalidRequestError,
+ "UPDATE construct does not support multiple parameter sets.",
+ session.execute,
+ stmt,
+ )
+
+ def test_update_preserve_parameter_order_future(self):
+ User = self.classes.User
+ session = Session()
+
+ # Do update using a tuple and check that order is preserved
+
+ stmt = (
+ update(User)
+ .filter(User.id == 15)
+ .ordered_values(("age", 123), ("name", "foob"))
+ )
+ result = session.execute(stmt)
+ cols = [
+ c.key
+ for c, v in (
+ (
+ result.context
+ ).compiled.compile_state.statement._ordered_values
)
- cols = [
- c.key for c, v in exec_.mock_calls[0][1][0]._ordered_values
- ]
- eq_(["name", "id"], cols)
+ ]
+ eq_(["age_int", "name"], cols)
+
+ # Now invert the order and use a list instead, and check that order is
+ # also preserved
+ stmt = (
+ update(User)
+ .filter(User.id == 15)
+ .ordered_values(("name", "foob"), ("age", 123),)
+ )
+ result = session.execute(stmt)
+ cols = [
+ c.key
+ for c, v in (
+ result.context
+ ).compiled.compile_state.statement._ordered_values
+ ]
+ eq_(["name", "age_int"], cols)
class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
def test_update_args(self):
Data = self.classes.Data
- session = testing.mock.Mock(wraps=Session())
+ session = Session()
update_args = {"mysql_limit": 1}
+ m1 = testing.mock.Mock()
+
+ @event.listens_for(session, "after_bulk_update")
+ def do_orm_execute(bulk_ud):
+ update_stmt = (
+ bulk_ud.result.context.compiled.compile_state.statement
+ )
+ m1(update_stmt)
+
q = session.query(Data)
- with testing.mock.patch.object(q, "_execute_crud") as exec_:
- q.update({Data.cnt: Data.cnt + 1}, update_args=update_args)
- eq_(exec_.call_count, 1)
- args, kwargs = exec_.mock_calls[0][1:3]
- eq_(len(args), 2)
- update_stmt = args[0]
+ q.update({Data.cnt: Data.cnt + 1}, update_args=update_args)
+
+ update_stmt = m1.mock_calls[0][1][0]
+
eq_(update_stmt.dialect_kwargs, update_args)
)
s.commit()
- def test_illegal_metadata(self):
+ @testing.only_on("mysql", "Multi table update")
+ def test_update_from_join_no_problem(self):
person = self.classes.Person.__table__
engineer = self.classes.Engineer.__table__
sess = Session()
- assert_raises_message(
- exc.InvalidRequestError,
- "This operation requires only one Table or entity be "
- "specified as the target.",
- sess.query(person.join(engineer)).update,
- {},
+ sess.query(person.join(engineer)).filter(person.c.name == "e2").update(
+ {person.c.name: "updated", engineer.c.engineer_name: "e2a"},
)
+ obj = sess.execute(
+ future_select(self.classes.Engineer).filter(
+ self.classes.Engineer.name == "updated"
+ )
+ ).scalar()
+ eq_(obj.name, "updated")
+ eq_(obj.engineer_name, "e2a")
def test_update_subtable_only(self):
Engineer = self.classes.Engineer
table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
+table_b_b = Table(
+ "b_b",
+ meta,
+ Column("a", Integer),
+ Column("b", Integer),
+ Column("c", Integer),
+ Column("d", Integer),
+ Column("e", Integer),
+)
+
table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
fixtures.append(_statements_w_anonymous_col_names)
+ def _update_dml_w_dicts():
+ return (
+ table_b_b.update().values(
+ {
+ table_b_b.c.a: 5,
+ table_b_b.c.b: 5,
+ table_b_b.c.c: 5,
+ table_b_b.c.d: 5,
+ }
+ ),
+ # equivalent, but testing dictionary insert ordering as cache key
+ # / compare
+ table_b_b.update().values(
+ {
+ table_b_b.c.a: 5,
+ table_b_b.c.c: 5,
+ table_b_b.c.b: 5,
+ table_b_b.c.d: 5,
+ }
+ ),
+ table_b_b.update().values(
+ {table_b_b.c.a: 5, table_b_b.c.b: 5, "c": 5, table_b_b.c.d: 5}
+ ),
+ table_b_b.update().values(
+ {
+ table_b_b.c.a: 5,
+ table_b_b.c.b: 5,
+ table_b_b.c.c: 5,
+ table_b_b.c.d: 5,
+ table_b_b.c.e: 10,
+ }
+ ),
+ table_b_b.update()
+ .values(
+ {
+ table_b_b.c.a: 5,
+ table_b_b.c.b: 5,
+ table_b_b.c.c: 5,
+ table_b_b.c.d: 5,
+ table_b_b.c.e: 10,
+ }
+ )
+ .where(table_b_b.c.c > 10),
+ )
+
+ if util.py37:
+ fixtures.append(_update_dml_w_dicts)
+
class CacheKeyFixture(object):
def _run_cache_key_fixture(self, fixture, compare_values):