"not supported in bulk_insert()")
connection = uowtransaction.transaction.connection(base_mapper)
-
+ value_params = {}
for table, sub_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(sub_mapper):
continue
- to_translate = dict(
- (mapper._columntoproperty[col].key, col.key)
- for col in mapper._cols_by_table[table]
- )
has_version_generator = mapper.version_id_generator is not False and \
mapper.version_id_col is not None
- multiparams = []
+
+ records = []
for mapping in mappings:
params = dict(
- (k, mapping.get(v)) for k, v in to_translate.items()
+ (col.key, mapping[propkey])
+ for col, propkey in mapper._col_to_propkey[table]
+ if propkey in mapping
)
+
if has_version_generator:
params[mapper.version_id_col.key] = \
mapper.version_id_generator(None)
- multiparams.append(params)
- statement = base_mapper._memo(('insert', table), table.insert)
- cached_connections[connection].execute(statement, multiparams)
+ records.append(
+ (None, None, params, sub_mapper,
+ connection, value_params, True, True)
+ )
+
+ _emit_insert_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, records,
+ bookkeeping=False)
def bulk_update(mapper, mappings, uowtransaction):
connection = uowtransaction.transaction.connection(base_mapper)
+ value_params = {}
for table, sub_mapper in base_mapper._sorted_tables.items():
if not mapper.isa(sub_mapper):
continue
- needs_version_id = sub_mapper.version_id_col is not None and \
- table.c.contains_column(sub_mapper.version_id_col)
-
- def update_stmt():
- return _update_stmt_for_mapper(sub_mapper, table, needs_version_id)
-
- statement = base_mapper._memo(('update', table), update_stmt)
+ label_pks = mapper._pks_by_table[table]
+ if mapper.version_id_col is not None:
+ label_pks = label_pks.union([mapper.version_id_col])
- pks = mapper._pks_by_table[table]
to_translate = dict(
- (mapper._columntoproperty[col].key, col._label
- if col in pks else col.key)
- for col in mapper._cols_by_table[table]
+ (propkey, col._label if col in label_pks else col.key)
+ for col, propkey in mapper._col_to_propkey[table]
)
- for colnames, sub_mappings in groupby(
- mappings,
- lambda mapping: sorted(tuple(mapping.keys()))):
-
- multiparams = []
- for mapping in sub_mappings:
- params = dict(
- (to_translate[k], v) for k, v in mapping.items()
- )
- multiparams.append(params)
-
- c = cached_connections[connection].execute(statement, multiparams)
+ records = []
+ for mapping in mappings:
+ params = dict(
+ (to_translate[k], v) for k, v in mapping.items()
+ )
- rows = c.rowcount
+ if mapper.version_id_generator is not False and \
+ mapper.version_id_col is not None and \
+ mapper.version_id_col.key not in params:
+ params[mapper.version_id_col.key] = \
+ mapper.version_id_generator(
+ params[mapper.version_id_col._label])
- if connection.dialect.supports_sane_rowcount:
- if rows != len(multiparams):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched." %
- (table.description, len(multiparams), rows))
+ records.append(
+ (None, None, params, sub_mapper, connection, value_params)
+ )
- elif needs_version_id:
- util.warn("Dialect %s does not support updated rowcount "
- "- versioning cannot be verified." %
- c.dialect.dialect_description,
- stacklevel=12)
+ _emit_update_statements(base_mapper, uowtransaction,
+ cached_connections,
+ mapper, table, records,
+ bookkeeping=False)
def save_obj(
if table not in mapper._pks_by_table:
continue
- pks = mapper._pks_by_table[table]
-
params = {}
value_params = {}
-
- has_all_pks = True
- has_all_defaults = True
- has_version_id_generator = mapper.version_id_generator is not False \
- and mapper.version_id_col is not None
- for col in mapper._cols_by_table[table]:
- if has_version_id_generator and col is mapper.version_id_col:
- val = mapper.version_id_generator(None)
- params[col.key] = val
+ for col, propkey in mapper._col_to_propkey[table]:
+ if propkey in state_dict:
+ value = state_dict[propkey]
+ if isinstance(value, sql.ClauseElement):
+ value_params[col.key] = value
+ elif value is not None or (
+ not col.primary_key and
+ not col.server_default and
+ not col.default):
+ params[col.key] = value
else:
- # pull straight from the dict for
- # pending objects
- prop = mapper._columntoproperty[col]
- value = state_dict.get(prop.key, None)
+ if not col.server_default \
+ and not col.default and not col.primary_key:
+ params[col.key] = None
- if value is None:
- if col in pks:
- has_all_pks = False
- elif col.default is None and \
- col.server_default is None:
- params[col.key] = value
- elif col.server_default is not None and \
- mapper.base_mapper.eager_defaults:
- has_all_defaults = False
+ has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
- elif isinstance(value, sql.ClauseElement):
- value_params[col] = value
- else:
- params[col.key] = value
+ if base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_default_cols[table].\
+ issubset(params)
+ else:
+ has_all_defaults = True
+
+ if mapper.version_id_generator is not False \
+ and mapper.version_id_col is not None and \
+ mapper.version_id_col in mapper._cols_by_table[table]:
+ params[mapper.version_id_col.key] = \
+ mapper.version_id_generator(None)
insert.append((state, state_dict, params, mapper,
connection, value_params, has_all_pks,
return delete
-def _update_stmt_for_mapper(mapper, table, needs_version_id):
- clause = sql.and_()
-
- for col in mapper._pks_by_table[table]:
- clause.clauses.append(col == sql.bindparam(col._label,
- type_=col.type))
-
- if needs_version_id:
- clause.clauses.append(
- mapper.version_id_col == sql.bindparam(
- mapper.version_id_col._label,
- type_=mapper.version_id_col.type))
-
- stmt = table.update(clause)
- if mapper.base_mapper.eager_defaults:
- stmt = stmt.return_defaults()
- elif mapper.version_id_col is not None:
- stmt = stmt.return_defaults(mapper.version_id_col)
-
- return stmt
-
-
def _emit_update_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, update):
+ cached_connections, mapper, table, update,
+ bookkeeping=True):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_update_commands()."""
table.c.contains_column(mapper.version_id_col)
def update_stmt():
- return _update_stmt_for_mapper(mapper, table, needs_version_id)
+ clause = sql.and_()
+
+ for col in mapper._pks_by_table[table]:
+ clause.clauses.append(col == sql.bindparam(col._label,
+ type_=col.type))
+
+ if needs_version_id:
+ clause.clauses.append(
+ mapper.version_id_col == sql.bindparam(
+ mapper.version_id_col._label,
+ type_=mapper.version_id_col.type))
+
+ stmt = table.update(clause)
+ if mapper.base_mapper.eager_defaults:
+ stmt = stmt.return_defaults()
+ elif mapper.version_id_col is not None:
+ stmt = stmt.return_defaults(mapper.version_id_col)
+
+ return stmt
statement = base_mapper._memo(('update', table), update_stmt)
c = connection.execute(
statement.values(value_params),
params)
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
+ if bookkeeping:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
execute(statement, multiparams)
rows += c.rowcount
- for state, state_dict, params, mapper, \
- connection, value_params in records:
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params)
+ if bookkeeping:
+ for state, state_dict, params, mapper, \
+ connection, value_params in records:
+ _postfetch(
+ mapper,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ c.context.compiled_parameters[0],
+ value_params)
if connection.dialect.supports_sane_rowcount:
if rows != len(records):
def _emit_insert_statements(base_mapper, uowtransaction,
- cached_connections, mapper, table, insert):
+ cached_connections, mapper, table, insert,
+ bookkeeping=True):
"""Emit INSERT statements corresponding to value lists collected
by _collect_insert_commands()."""
for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
records in groupby(insert,
lambda rec: (rec[4],
- list(rec[2].keys()),
+ tuple(sorted(rec[2].keys())),
bool(rec[5]),
rec[6], rec[7])
):
- if \
+ if not bookkeeping or \
(
has_all_defaults
or not base_mapper.eager_defaults
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper_rec,
- conn, value_params, has_all_pks, has_all_defaults), \
- last_inserted_params in \
- zip(records, c.context.compiled_parameters):
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params)
+ if bookkeeping:
+ for (state, state_dict, params, mapper_rec,
+ conn, value_params, has_all_pks, has_all_defaults), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ _postfetch(
+ mapper_rec,
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c,
+ last_inserted_params,
+ value_params)
else:
if not has_all_defaults and base_mapper.eager_defaults: