try:
if item_type == 'property':
prop = iterator.next()
- visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None))
+ visitables.append((prop.cascade_iterator(type_, parent_state,
+ visited_instances, halt_on), 'mapper', None))
elif item_type == 'mapper':
instance, instance_mapper, corresponding_state = iterator.next()
yield (instance, instance_mapper)
- visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state))
+ visitables.append((instance_mapper._props.itervalues(),
+ 'property', corresponding_state))
except StopIteration:
visitables.pop()
# if batch=false, call _save_obj separately for each object
if not single and not self.batch:
for state in _sort_states(states):
- self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
+ self._save_obj([state],
+ uowtransaction,
+ postupdate=postupdate,
+ post_update_cols=post_update_cols,
+ single=True)
return
-
+
# if session has a connection callable,
- # organize individual states with the connection to use for insert/update
- tups = []
+ # organize individual states with the connection
+ # to use for insert/update
if 'connection_callable' in uowtransaction.mapper_flush_opts:
- connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- for state in _sort_states(states):
- m = _state_mapper(state)
- tups.append(
- (
- state,
- m,
- connection_callable(self, state.obj()),
- _state_has_identity(state),
- state.key or m._identity_key_from_state(state)
- )
- )
+ connection_callable = \
+ uowtransaction.mapper_flush_opts['connection_callable']
else:
connection = uowtransaction.transaction.connection(self)
- for state in _sort_states(states):
- m = _state_mapper(state)
- tups.append(
- (
- state,
- m,
- connection,
- _state_has_identity(state),
- state.key or m._identity_key_from_state(state)
- )
- )
+ connection_callable = None
- if not postupdate:
- # call before_XXX extensions
- for state, mapper, connection, has_identity, instance_key in tups:
+ tups = []
+ for state in _sort_states(states):
+ conn = connection_callable and \
+ connection_callable(self, state.obj()) or \
+ connection
+
+ has_identity = _state_has_identity(state)
+ mapper = _state_mapper(state)
+ instance_key = state.key or mapper._identity_key_from_state(state)
+
+ row_switch = None
+ if not postupdate:
+ # call before_XXX extensions
if not has_identity:
if 'before_insert' in mapper.extension:
- mapper.extension.before_insert(mapper, connection, state.obj())
+ mapper.extension.before_insert(mapper, conn, state.obj())
else:
if 'before_update' in mapper.extension:
- mapper.extension.before_update(mapper, connection, state.obj())
+ mapper.extension.before_update(mapper, conn, state.obj())
- row_switches = {}
- if not postupdate:
- for state, mapper, connection, has_identity, instance_key in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
- # and another instance with the same identity key already exists as persistent. convert to an
- # UPDATE if so.
+ # and another instance with the same identity key already exists as persistent.
+ # convert to an UPDATE if so.
if not has_identity and instance_key in uowtransaction.session.identity_map:
instance = uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
"New instance %s with identity key %s conflicts "
"with persistent instance %s" %
(state_str(state), instance_key, state_str(existing)))
-
+
self._log_debug(
- "detected row switch for identity %s. will update %s, remove %s from "
- "transaction", instance_key, state_str(state), state_str(existing))
-
+ "detected row switch for identity %s. "
+ "will update %s, remove %s from "
+ "transaction", instance_key,
+ state_str(state), state_str(existing))
+
# remove the "delete" flag from the existing element
uowtransaction.set_row_switch(existing)
- row_switches[state] = existing
-
- table_to_mapper = self._sorted_tables
+ row_switch = existing
+
+ tups.append(
+ (state,
+ mapper,
+ conn,
+ has_identity,
+ instance_key,
+ row_switch)
+ )
- for table in table_to_mapper.iterkeys():
+ table_to_mapper = self._sorted_tables
+
+ for table in table_to_mapper:
insert = []
update = []
- for state, mapper, connection, has_identity, instance_key in tups:
+ for state, mapper, connection, has_identity, \
+ instance_key, row_switch in tups:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
- isinsert = not has_identity and not postupdate and state not in row_switches
+ isinsert = not has_identity and \
+ not postupdate and \
+ not row_switch
params = {}
value_params = {}
value_params[col] = value
else:
params[col.key] = value
- insert.append((state, params, mapper, connection, value_params))
+ insert.append((state, params, mapper,
+ connection, value_params))
else:
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
- params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col)
- params[col.key] = mapper.version_id_generator(params[col._label])
+ params[col._label] = \
+ mapper._get_state_attr_by_column(
+ row_switch or state,
+ col)
+ params[col.key] = \
+ mapper.version_id_generator(params[col._label])
+
+ # HACK: check for history, in case the history is only
+ # in a different table than the one where the version_id_col
+ # is.
for prop in mapper._columntoproperty.itervalues():
- history = attributes.get_state_history(state, prop.key, passive=True)
+ history = attributes.get_state_history(
+ state, prop.key, passive=True)
if history.added:
hasdata = True
elif mapper.polymorphic_on is not None and \
- mapper.polymorphic_on.shares_lineage(col) and col not in pks:
+ mapper.polymorphic_on.shares_lineage(col) and \
+ col not in pks:
pass
else:
- if post_update_cols is not None and col not in post_update_cols:
+ if post_update_cols is not None and \
+ col not in post_update_cols:
if col in pks:
- params[col._label] = mapper._get_state_attr_by_column(state, col)
+ params[col._label] = \
+ mapper._get_state_attr_by_column(state, col)
continue
prop = mapper._columntoproperty[col]
elif col in pks:
params[col._label] = mapper._get_state_attr_by_column(state, col)
if hasdata:
- update.append((state, params, mapper, connection, value_params))
+ update.append((state, params, mapper,
+ connection, value_params))
if update:
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label, type_=col.type))
+ clause.clauses.append(
+ col ==
+ sql.bindparam(col._label, type_=col.type)
+ )
- if mapper.version_id_col is not None and \
- table.c.contains_column(mapper.version_id_col):
-
+ needs_version_id = mapper.version_id_col is not None and \
+ table.c.contains_column(mapper.version_id_col)
+
+ if needs_version_id:
clause.clauses.append(mapper.version_id_col ==\
sql.bindparam(mapper.version_id_col._label, type_=col.type))
statement = table.update(clause)
-
+
rows = 0
for state, params, mapper, connection, value_params in update:
c = connection.execute(statement.values(value_params), params)
- mapper._postfetch(uowtransaction, connection, table,
+ mapper._postfetch(uowtransaction, table,
state, c, c.last_updated_params(), value_params)
rows += c.rowcount
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.ConcurrentModificationError(
- "Updated rowcount %d does not match number of objects updated %d" %
+ "Updated rowcount %d does not match number "
+ "of objects updated %d" %
(rows, len(update)))
-
- elif mapper.version_id_col is not None:
+
+ elif needs_version_id:
util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." % c.dialect.dialect_description,
- stacklevel=12)
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
if insert:
statement = table.insert()
len(primary_key) > i:
mapper._set_state_attr_by_column(state, col, primary_key[i])
- mapper._postfetch(uowtransaction, connection, table,
+ mapper._postfetch(uowtransaction, table,
state, c, c.last_inserted_params(), value_params)
-
if not postupdate:
- for state, mapper, connection, has_identity, instance_key in tups:
+ for state, mapper, connection, has_identity, \
+ instance_key, row_switch in tups:
# expire readonly attributes
readonly = state.unmodified.intersection(
if readonly:
_expire_state(state, state.dict, readonly)
- # if specified, eagerly refresh whatever has
- # been expired.
+ # if eager_defaults option is enabled,
+ # refresh whatever has been expired.
if self.eager_defaults and state.unloaded:
state.key = self._identity_key_from_state(state)
uowtransaction.session.query(self)._get(
if 'after_update' in mapper.extension:
mapper.extension.after_update(mapper, connection, state.obj())
- def _postfetch(self, uowtransaction, connection, table,
+ def _postfetch(self, uowtransaction, table,
state, resultproxy, params, value_params):
"""Expire attributes in need of newly persisted database state."""
if c.key in params and c in self._columntoproperty:
self._set_state_attr_by_column(state, c, params[c.key])
- deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
-
- if deferred_props:
- _expire_state(state, state.dict, deferred_props)
+ if postfetch_cols:
+ _expire_state(state, state.dict,
+ [self._columntoproperty[c].key
+ for c in postfetch_cols]
+ )
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
- cols = set(table.c)
- for m in self.iterate_to_root():
- if m._inherits_equated_pairs and \
- cols.intersection([l for l, r in m._inherits_equated_pairs]):
- sync.populate(state, m, state, m,
- m._inherits_equated_pairs,
- uowtransaction,
- self.passive_updates)
-
+ for m, equated_pairs in self._table_to_equated[table]:
+ sync.populate(state, m, state, m,
+ equated_pairs,
+ uowtransaction,
+ self.passive_updates)
+
+ @util.memoized_property
+ def _table_to_equated(self):
+ """memoized map of tables to collections of columns to be
+ synchronized upwards to the base mapper."""
+
+ result = util.defaultdict(list)
+
+ for table in self._sorted_tables:
+ cols = set(table.c)
+ for m in self.iterate_to_root():
+ if m._inherits_equated_pairs and \
+ cols.intersection([l for l, r in m._inherits_equated_pairs]):
+ result[table].append((m, m._inherits_equated_pairs))
+
+ return result
+
def _delete_obj(self, states, uowtransaction):
"""Issue ``DELETE`` statements for a list of objects.
"""
if 'connection_callable' in uowtransaction.mapper_flush_opts:
- connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
- tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)]
+ connection_callable = \
+ uowtransaction.mapper_flush_opts['connection_callable']
else:
connection = uowtransaction.transaction.connection(self)
- tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
-
- for state, mapper, connection in tups:
+ connection_callable = None
+
+ tups = []
+ for state in _sort_states(states):
+ mapper = _state_mapper(state)
+
+ conn = connection_callable and \
+ connection_callable(self, state.obj()) or \
+ connection
+
if 'before_delete' in mapper.extension:
- mapper.extension.before_delete(mapper, connection, state.obj())
+ mapper.extension.before_delete(mapper, conn, state.obj())
+
+ tups.append((state,
+ _state_mapper(state),
+ _state_has_identity(state),
+ conn))
table_to_mapper = self._sorted_tables
for table in reversed(table_to_mapper.keys()):
- delete = {}
- for state, mapper, connection in tups:
- if table not in mapper._pks_by_table:
+ delete = util.defaultdict(list)
+ for state, mapper, has_identity, connection in tups:
+ if not has_identity or table not in mapper._pks_by_table:
continue
params = {}
- if not _state_has_identity(state):
- continue
- else:
- delete.setdefault(connection, []).append(params)
+ delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = mapper._get_state_attr_by_column(state, col)
- if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
- params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col)
+ if mapper.version_id_col is not None and \
+ table.c.contains_column(mapper.version_id_col):
+ params[mapper.version_id_col.key] = \
+ mapper._get_state_attr_by_column(state, mapper.version_id_col)
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type))
- if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
+
+ need_version_id = mapper.version_id_col is not None and \
+ table.c.contains_column(mapper.version_id_col)
+
+ if need_version_id:
clause.clauses.append(
mapper.version_id_col ==
- sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type))
+ sql.bindparam(
+ mapper.version_id_col.key,
+ type_=mapper.version_id_col.type
+ )
+ )
+
statement = table.delete(clause)
- c = connection.execute(statement, del_objects)
- if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
- raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match "
- "number of objects deleted %d" % (c.rowcount, len(del_objects)))
+ rows = -1
+
+ if need_version_id and \
+ not connection.dialect.supports_sane_multi_rowcount:
+ # TODO: need test coverage for this [ticket:1761]
+ if connection.dialect.supports_sane_rowcount:
+ rows = 0
+ # execute deletes individually so that versioned
+ # rows can be verified
+ for params in del_objects:
+ c = connection.execute(statement, params)
+ rows += c.rowcount
+ else:
+ util.warn("Dialect %s does not support deleted rowcount "
+ "- versioning cannot be verified." %
+ c.dialect.dialect_description,
+ stacklevel=12)
+ connection.execute(statement, del_objects)
+ else:
+ c = connection.execute(statement, del_objects)
+ if connection.dialect.supports_sane_multi_rowcount:
+ rows = c.rowcount
+
+ if rows != -1 and rows != len(del_objects):
+ raise orm_exc.ConcurrentModificationError(
+ "Deleted rowcount %d does not match "
+ "number of objects deleted %d" %
+ (c.rowcount, len(del_objects))
+ )
- for state, mapper, connection in tups:
+ for state, mapper, has_identity, connection in tups:
if 'after_delete' in mapper.extension:
mapper.extension.after_delete(mapper, connection, state.obj())