import types
import weakref
import operator
-from itertools import chain
+from itertools import chain, groupby
deque = __import__('collections').deque
from sqlalchemy import sql, util, log, exc as sa_exc
self._memoized_values[key] = value = callable_()
return value
- def _save_obj(self, states, uowtransaction, postupdate=False,
- post_update_cols=None, single=False):
+ def _post_update(self, states, uowtransaction, post_update_cols):
+ """Issue UPDATE statements on behalf of a relationship() which
+ specifies post_update.
+
+ """
+ cached_connections = util.PopulateDict(
+ lambda conn:conn.execution_options(
+ compiled_cache=self._compiled_cache
+ ))
+
+ # if session has a connection callable,
+ # organize individual states with the connection
+ # to use for update
+ if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ connection_callable = \
+ uowtransaction.mapper_flush_opts['connection_callable']
+ else:
+ connection = uowtransaction.transaction.connection(self)
+ connection_callable = None
+
+ tups = []
+ for state in _sort_states(states):
+ if connection_callable:
+ conn = connection_callable(self, state.obj())
+ else:
+ conn = connection
+
+ mapper = _state_mapper(state)
+
+ tups.append((state, state.dict, mapper, conn))
+
+ table_to_mapper = self._sorted_tables
+
+ for table in table_to_mapper:
+ update = []
+
+ for state, state_dict, mapper, connection in tups:
+ if table not in mapper._pks_by_table:
+ continue
+
+ pks = mapper._pks_by_table[table]
+ params = {}
+ hasdata = False
+
+ for col in mapper._cols_by_table[table]:
+ if col in pks:
+ params[col._label] = \
+ mapper._get_state_attr_by_column(
+ state,
+ state_dict, col)
+ elif col in post_update_cols:
+ prop = mapper._columntoproperty[col]
+ history = attributes.get_state_history(
+ state, prop.key, passive=True)
+ if history.added:
+ params[col.key] = \
+ prop.get_col_value(col,
+ history.added[0])
+ hasdata = True
+ if hasdata:
+ update.append((state, state_dict, params, mapper,
+ connection))
+
+ if update:
+ mapper = table_to_mapper[table]
+
+ def update_stmt():
+ clause = sql.and_()
+
+ for col in mapper._pks_by_table[table]:
+ clause.clauses.append(col == sql.bindparam(col._label,
+ type_=col.type))
+
+ return table.update(clause)
+
+ statement = self._memo(('post_update', table), update_stmt)
+
+ # execute each UPDATE in the order according to the original
+ # list of states to guarantee row access order, but
+ # also group them into common (connection, cols) sets
+ # to support executemany().
+ for key, grouper in groupby(
+ update, lambda rec: (rec[4], rec[2].keys())
+ ):
+ multiparams = [params for state, state_dict,
+ params, mapper, conn in grouper]
+ cached_connections[connection].\
+ execute(statement, multiparams)
+
+ def _save_obj(self, states, uowtransaction, single=False):
"""Issue ``INSERT`` and/or ``UPDATE`` statements for a list
of objects.
mapper which does not inherit from any other mapper.
"""
+
# if batch=false, call _save_obj separately for each object
if not single and not self.batch:
for state in _sort_states(states):
self._save_obj([state],
uowtransaction,
- postupdate=postupdate,
- post_update_cols=post_update_cols,
single=True)
return
- cached_connections = util.PopulateDict(
- lambda conn:conn.execution_options(
- compiled_cache=self._compiled_cache
- ))
-
# if session has a connection callable,
# organize individual states with the connection
# to use for insert/update
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
- if not postupdate:
- # call before_XXX extensions
- if not has_identity:
- if 'before_insert' in mapper.extension:
- mapper.extension.before_insert(
- mapper, conn, state.obj())
- else:
- if 'before_update' in mapper.extension:
- mapper.extension.before_update(
- mapper, conn, state.obj())
-
- # detect if we have a "pending" instance (i.e. has
- # no instance_key attached to it), and another instance
- # with the same identity key already exists as persistent.
- # convert to an UPDATE if so.
- if not has_identity and \
- instance_key in uowtransaction.session.identity_map:
- instance = \
- uowtransaction.session.identity_map[instance_key]
- existing = attributes.instance_state(instance)
- if not uowtransaction.is_deleted(existing):
- raise orm_exc.FlushError(
- "New instance %s with identity key %s conflicts "
- "with persistent instance %s" %
- (state_str(state), instance_key,
- state_str(existing)))
-
- self._log_debug(
- "detected row switch for identity %s. "
- "will update %s, remove %s from "
- "transaction", instance_key,
- state_str(state), state_str(existing))
-
- # remove the "delete" flag from the existing element
- uowtransaction.remove_state_actions(existing)
- row_switch = existing
+ # call before_XXX extensions
+ if not has_identity:
+ if 'before_insert' in mapper.extension:
+ mapper.extension.before_insert(
+ mapper, conn, state.obj())
+ else:
+ if 'before_update' in mapper.extension:
+ mapper.extension.before_update(
+ mapper, conn, state.obj())
+
+ # detect if we have a "pending" instance (i.e. has
+ # no instance_key attached to it), and another instance
+ # with the same identity key already exists as persistent.
+ # convert to an UPDATE if so.
+ if not has_identity and \
+ instance_key in uowtransaction.session.identity_map:
+ instance = \
+ uowtransaction.session.identity_map[instance_key]
+ existing = attributes.instance_state(instance)
+ if not uowtransaction.is_deleted(existing):
+ raise orm_exc.FlushError(
+ "New instance %s with identity key %s conflicts "
+ "with persistent instance %s" %
+ (state_str(state), instance_key,
+ state_str(existing)))
+
+ self._log_debug(
+ "detected row switch for identity %s. "
+ "will update %s, remove %s from "
+ "transaction", instance_key,
+ state_str(state), state_str(existing))
+
+ # remove the "delete" flag from the existing element
+ uowtransaction.remove_state_actions(existing)
+ row_switch = existing
tups.append(
- (state,
- state.dict,
- mapper,
- conn,
- has_identity,
- instance_key,
- row_switch)
+ (state, state.dict, mapper, conn,
+ has_identity, instance_key, row_switch)
)
+ # dictionary of connection->connection_with_cache_options.
+ cached_connections = util.PopulateDict(
+ lambda conn:conn.execution_options(
+ compiled_cache=self._compiled_cache
+ ))
+
table_to_mapper = self._sorted_tables
for table in table_to_mapper:
pks = mapper._pks_by_table[table]
isinsert = not has_identity and \
- not postupdate and \
not row_switch
params = {}
col not in pks:
pass
else:
- if post_update_cols is not None and \
- col not in post_update_cols:
- if col in pks:
- params[col._label] = mapper.\
- _get_state_attr_by_column(
- state, state_dict, col)
- continue
-
prop = mapper._columntoproperty[col]
history = attributes.get_state_history(
state, prop.key, passive=True)
connection, value_params))
if update:
-
mapper = table_to_mapper[table]
needs_version_id = mapper.version_id_col is not None and \
clause = sql.and_()
for col in mapper._pks_by_table[table]:
- clause.clauses.append(
- col == sql.bindparam(col._label,
- type_=col.type)
- )
+ clause.clauses.append(col == sql.bindparam(col._label,
+ type_=col.type))
if needs_version_id:
clause.clauses.append(mapper.version_id_col ==\
c.last_inserted_params(),
value_params)
- if not postupdate:
- for state, state_dict, mapper, connection, has_identity, \
- instance_key, row_switch in tups:
+ for state, state_dict, mapper, connection, has_identity, \
+ instance_key, row_switch in tups:
- # expire readonly attributes
- readonly = state.unmodified.intersection(
- p.key for p in mapper._readonly_props
- )
-
- if readonly:
- _expire_state(state, state.dict, readonly)
-
- # if eager_defaults option is enabled,
- # refresh whatever has been expired.
- if self.eager_defaults and state.unloaded:
- state.key = self._identity_key_from_state(state)
- uowtransaction.session.query(self)._get(
- state.key, refresh_state=state,
- only_load_props=state.unloaded)
-
- # call after_XXX extensions
- if not has_identity:
- if 'after_insert' in mapper.extension:
- mapper.extension.after_insert(
- mapper, connection, state.obj())
- else:
- if 'after_update' in mapper.extension:
- mapper.extension.after_update(
- mapper, connection, state.obj())
+ # expire readonly attributes
+ readonly = state.unmodified.intersection(
+ p.key for p in mapper._readonly_props
+ )
+
+ if readonly:
+ _expire_state(state, state.dict, readonly)
+
+ # if eager_defaults option is enabled,
+ # refresh whatever has been expired.
+ if self.eager_defaults and state.unloaded:
+ state.key = self._identity_key_from_state(state)
+ uowtransaction.session.query(self)._get(
+ state.key, refresh_state=state,
+ only_load_props=state.unloaded)
+
+ # call after_XXX extensions
+ if not has_identity:
+ if 'after_insert' in mapper.extension:
+ mapper.extension.after_insert(
+ mapper, connection, state.obj())
+ else:
+ if 'after_update' in mapper.extension:
+ mapper.extension.after_update(
+ mapper, connection, state.obj())
def _postfetch(self, uowtransaction, table,
state, dict_, resultproxy,
params, value_params):
- """Expire attributes in need of newly persisted database state."""
+ """During a flush, expire attributes in need of newly
+ persisted database state."""
postfetch_cols = resultproxy.postfetch_cols()
generated_cols = list(resultproxy.prefetch_cols())
"VALUES (:favorite_ball_id, :data)",
lambda ctx:{'favorite_ball_id':b.id, 'data':'some data'}),
- AllOf(
- CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b.id}),
-
- CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b2.id}),
-
- CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b3.id}),
-
- CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id':p.id,'ball_id':b4.id})
+ CompiledSQL("UPDATE ball SET person_id=:person_id "
+ "WHERE ball.id = :ball_id",
+ lambda ctx:[
+ {'person_id':p.id,'ball_id':b.id},
+ {'person_id':p.id,'ball_id':b2.id},
+ {'person_id':p.id,'ball_id':b3.id},
+ {'person_id':p.id,'ball_id':b4.id}
+ ]
),
+
)
sess.delete(p)
self.assert_sql_execution(testing.db, sess.flush,
- AllOf(CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id': None, 'ball_id': b.id}),
-
CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id': None, 'ball_id': b2.id}),
-
- CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id': None, 'ball_id': b3.id}),
-
- CompiledSQL("UPDATE ball SET person_id=:person_id "
- "WHERE ball.id = :ball_id",
- lambda ctx:{'person_id': None, 'ball_id': b4.id})),
-
+ "WHERE ball.id = :ball_id",
+ lambda ctx:[
+ {'person_id': None, 'ball_id': b.id},
+ {'person_id': None, 'ball_id': b2.id},
+ {'person_id': None, 'ball_id': b3.id},
+ {'person_id': None, 'ball_id': b4.id}
+ ]
+ ),
CompiledSQL("DELETE FROM person WHERE person.id = :id",
lambda ctx:[{'id':p.id}]),
)
session.delete(root)
+
self.assert_sql_execution(
testing.db,
session.flush,
- AllOf(
- CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
- "WHERE node.id = :node_id",
- lambda ctx:{'next_sibling_id':None, 'node_id':about.id}),
- CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
- "WHERE node.id = :node_id",
- lambda ctx:{'node_id':stories.id, 'next_sibling_id':None})
+ CompiledSQL("UPDATE node SET next_sibling_id=:next_sibling_id "
+ "WHERE node.id = :node_id",
+ lambda ctx: [
+ {'node_id': about.id, 'next_sibling_id': None},
+ {'node_id': stories.id, 'next_sibling_id': None}
+ ]
),
AllOf(
CompiledSQL("DELETE FROM node WHERE node.id = :id",