self.cursor = self.create_cursor()
if self.isinsert or self.isupdate:
self.__process_defaults()
-
+ self.postfetch_cols = self.compiled.postfetch
+ self.prefetch_cols = self.compiled.prefetch
+
processors = dict(
(key, value) for key, value in
( (compiled.bind_names[bindparam],
"""
conn = self._connection
- if isinstance(stmt, unicode) and not self.dialect.supports_unicode_statements:
+ if isinstance(stmt, unicode) and \
+ not self.dialect.supports_unicode_statements:
stmt = stmt.encode(self.dialect.encoding)
if self.dialect.positional:
def post_insert(self):
if self.dialect.postfetch_lastrowid and \
- (not len(self._inserted_primary_key) or \
- None in self._inserted_primary_key):
+ (not len(self.inserted_primary_key) or \
+ None in self.inserted_primary_key):
table = self.compiled.statement.table
lastrowid = self.get_lastrowid()
- self._inserted_primary_key = [c is table._autoincrement_column and lastrowid or v
- for c, v in zip(table.primary_key, self._inserted_primary_key)
+ self.inserted_primary_key = [
+ c is table._autoincrement_column and lastrowid or v
+ for c, v in zip(table.primary_key, self.inserted_primary_key)
]
def _fetch_implicit_returning(self, resultproxy):
row = resultproxy.fetchone()
ipk = []
- for c, v in zip(table.primary_key, self._inserted_primary_key):
+ for c, v in zip(table.primary_key, self.inserted_primary_key):
if v is not None:
ipk.append(v)
else:
ipk.append(row[c])
- self._inserted_primary_key = ipk
+ self.inserted_primary_key = ipk
def lastrow_has_defaults(self):
- return hasattr(self, 'postfetch_cols') and len(self.postfetch_cols)
+ return (self.isinsert or self.isupdate) and \
+ bool(self.postfetch_cols)
def set_input_sizes(self, translate=None, exclude_types=None):
"""Given a cursor and ClauseParameters, call the appropriate
else:
return self._exec_default(column.onupdate)
- @util.memoized_property
- def _inserted_primary_key(self):
-
- if not self.isinsert:
- raise exc.InvalidRequestError(
- "Statement is not an insert() expression construct.")
- elif self._is_explicit_returning:
- raise exc.InvalidRequestError(
- "Can't call inserted_primary_key when returning() "
- "is used.")
-
-
- # lazyily evaluate inserted_primary_key for executemany.
- # for execute(), its already in __dict__.
- if self.executemany:
- return [
- [compiled_parameters.get(c.key, None)
- for c in self.compiled.\
- statement.table.primary_key
- ] for compiled_parameters in self.compiled_parameters
- ]
- else:
- # _inserted_primary_key should be calced here
- assert False
-
def __process_defaults(self):
"""Generate default values for compiled insert/update statements,
and generate inserted_primary_key collection.
if val is not None:
param[c.key] = val
del self.current_parameters
-
- if self.isinsert:
- self.last_inserted_params = self.compiled_parameters
- else:
- self.last_updated_params = self.compiled_parameters
-
else:
self.current_parameters = compiled_parameters = \
self.compiled_parameters[0]
compiled_parameters[c.key] = val
del self.current_parameters
- if self.isinsert and not self._is_explicit_returning:
- self._inserted_primary_key = [
+ if self.isinsert:
+ self.inserted_primary_key = [
self.compiled_parameters[0].get(c.key, None)
for c in self.compiled.\
statement.table.primary_key
]
- if self.isinsert:
- self.last_inserted_params = compiled_parameters
- else:
- self.last_updated_params = compiled_parameters
-
- self.postfetch_cols = self.compiled.postfetch
- self.prefetch_cols = self.compiled.prefetch
DefaultDialect.execution_ctx_cls = DefaultExecutionContext
# if session has a connection callable,
# organize individual states with the connection
# to use for update
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
of objects.
This is called within the context of a UOWTransaction during a
- flush operation.
+ flush operation, given a list of states to be flushed. The
+ base mapper in an inheritance hierarchy handles the inserts/
+ updates for all descendant mappers.
- `_save_obj` issues SQL statements not just for instances mapped
- directly by this mapper, but for instances mapped by all
- inheriting mappers as well. This is to maintain proper insert
- ordering among a polymorphic chain of instances. Therefore
- _save_obj is typically called only on a *base mapper*, or a
- mapper which does not inherit from any other mapper.
-
"""
# if batch=false, call _save_obj separately for each object
# if session has a connection callable,
# organize individual states with the connection
# to use for insert/update
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
+
# call before_XXX extensions
if not has_identity:
mapper.dispatch.on_before_insert(mapper, conn, state)
params = {}
value_params = {}
- hasdata = False
- has_all_pks = True
if isinsert:
+ has_all_pks = True
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col.key] = \
value = prop.get_col_value(col, value)
if value is None:
- if col.default is None and \
- col.server_default is None and \
- col not in pks:
-
- params[col.key] = value
- elif col in pks:
+ if col in pks:
has_all_pks = False
+ elif col.default is None and \
+ col.server_default is None:
+ params[col.key] = value
+
elif isinstance(value, sql.ClauseElement):
value_params[col] = value
else:
insert.append((state, state_dict, params, mapper,
connection, value_params, has_all_pks))
else:
+ hasdata = False
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = \
else:
hasdata = True
elif col in pks:
- value = state.manager[prop.key].impl.get(state, state_dict)
+ value = state.manager[prop.key].\
+ impl.get(state, state_dict)
if prop.get_col_value:
value = prop.get_col_value(col, value)
params[col._label] = value
statement = self._memo(('update', table), update_stmt)
rows = 0
- postfetch = []
for state, state_dict, params, mapper, \
connection, value_params in update:
c = cached_connections[connection].\
execute(statement, params)
- postfetch.append((mapper, state, state_dict,
- c.prefetch_cols(), c.postfetch_cols(),
- c.last_updated_params(), value_params))
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ c.context.compiled_parameters[0],
+ value_params)
rows += c.rowcount
- for mapper, pf in groupby(
- postfetch, lambda rec: rec[0]
- ):
- mapper._postfetch(uowtransaction, table, pf)
-
-
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.StaleDataError(
if insert:
statement = self._memo(('insert', table), table.insert)
- postfetch = []
- for (connection, pkeys, hasvalue, has_all_pks), records in groupby(
- insert, lambda rec: (rec[4], rec[2].keys(), bool(rec[5]), rec[6])
+ for (connection, pkeys, hasvalue, has_all_pks), \
+ records in groupby(insert,
+ lambda rec: (rec[4],
+ rec[2].keys(),
+ bool(rec[5]),
+ rec[6])
):
if has_all_pks and not hasvalue:
records = list(records)
- multiparams = [params for state, state_dict,
- params, mapper, conn, value_params,
- has_all_pks in records]
+ multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
execute(statement, multiparams)
- for (state, state_dict, params, mapper, conn, value_params, has_all_pks), \
- last_inserted_params in zip(records, c.context.compiled_parameters):
- postfetch.append((mapper, state, state_dict,
- c.prefetch_cols(), c.postfetch_cols(),
- last_inserted_params, {}))
+ for (state, state_dict, params, mapper,
+ conn, value_params, has_all_pks), \
+ last_inserted_params in \
+ zip(records, c.context.compiled_parameters):
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ c.context.prefetch_cols,
+ c.context.postfetch_cols,
+ last_inserted_params,
+ value_params)
else:
for state, state_dict, params, mapper, \
- connection, value_params, has_all_pks in records:
+ connection, value_params, \
+ has_all_pks in records:
if value_params:
- c = connection.execute(
- statement.values(value_params),
- params)
+ result = connection.execute(
+ statement.values(value_params),
+ params)
else:
- c = cached_connections[connection].\
+ result = cached_connections[connection].\
execute(statement, params)
- primary_key = c.inserted_primary_key
+ primary_key = result.context.inserted_primary_key
if primary_key is not None:
# set primary key attributes
for pk, col in zip(primary_key,
mapper._pks_by_table[table]):
- # TODO: make sure this inlined code is OK
- # with composites
prop = mapper._columntoproperty[col]
if state_dict.get(prop.key) is None:
# TODO: would rather say:
#state_dict[prop.key] = pk
- mapper._set_state_attr_by_column(state,
- state_dict,
- col, pk)
+ mapper._set_state_attr_by_column(
+ state,
+ state_dict,
+ col, pk)
+
+ mapper._postfetch(
+ uowtransaction,
+ table,
+ state,
+ state_dict,
+ result.context.prefetch_cols,
+ result.context.postfetch_cols,
+ result.context.compiled_parameters[0],
+ value_params)
- postfetch.append((mapper, state, state_dict,
- c.prefetch_cols(), c.postfetch_cols(),
- c.last_inserted_params(), value_params))
-
- for mapper, pf in groupby(
- postfetch, lambda rec: rec[0]
- ):
- mapper._postfetch(uowtransaction, table, pf)
for state, state_dict, mapper, connection, has_identity, \
instance_key, row_switch in tups:
mapper.dispatch.on_after_update(mapper, connection, state)
def _postfetch(self, uowtransaction, table,
- recs):
+ state, dict_, prefetch_cols, postfetch_cols,
+ params, value_params):
"""During a flush, expire attributes in need of newly
persisted database state."""
- for m, state, dict_, prefetch_cols, postfetch_cols, \
- params, value_params in recs:
- postfetch_cols = postfetch_cols
- generated_cols = list(prefetch_cols)
-
- if self.version_id_col is not None:
- generated_cols.append(self.version_id_col)
-
- for c in generated_cols:
- if c.key in params and c in self._columntoproperty:
- self._set_state_attr_by_column(state, dict_, c, params[c.key])
-
- if postfetch_cols:
- sessionlib._expire_state(state, state.dict,
- [self._columntoproperty[c].key
- for c in postfetch_cols]
- )
-
- # synchronize newly inserted ids from one table to the next
- # TODO: this still goes a little too often. would be nice to
- # have definitive list of "columns that changed" here
- for m, equated_pairs in self._table_to_equated[table]:
- sync.populate(state, m, state, m,
- equated_pairs,
- uowtransaction,
- self.passive_updates)
+ if self.version_id_col is not None:
+ prefetch_cols = list(prefetch_cols) + [self.version_id_col]
+
+ for c in prefetch_cols:
+ if c.key in params and c in self._columntoproperty:
+ self._set_state_attr_by_column(state, dict_, c, params[c.key])
+
+ if postfetch_cols:
+ sessionlib._expire_state(state, state.dict,
+ [self._columntoproperty[c].key
+ for c in postfetch_cols]
+ )
+
+ # synchronize newly inserted ids from one table to the next
+ # TODO: this still goes a little too often. would be nice to
+ # have definitive list of "columns that changed" here
+ for m, equated_pairs in self._table_to_equated[table]:
+ sync.populate(state, m, state, m,
+ equated_pairs,
+ uowtransaction,
+ self.passive_updates)
@util.memoized_property
def _table_to_equated(self):
flush operation.
"""
- if 'connection_callable' in uowtransaction.mapper_flush_opts:
+ if uowtransaction.session.connection_callable:
connection_callable = \
- uowtransaction.mapper_flush_opts['connection_callable']
+ uowtransaction.session.connection_callable
else:
connection = uowtransaction.transaction.connection(self)
connection_callable = None
from sqlalchemy import Integer, String, ForeignKey, func
from test.orm import _fixtures, _base
from sqlalchemy.orm import mapper, relationship, backref, \
- create_session, unitofwork, attributes
+ create_session, unitofwork, attributes,\
+ Session
from test.lib.assertsql import AllOf, CompiledSQL
from test.orm._fixtures import keywords, addresses, Base, Keyword, \
sess.flush()
+class BatchInsertsTest(_base.MappedTest, testing.AssertsExecutionResults):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('t', metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('data', String(50)),
+ Column('def_', String(50), server_default='def1')
+ )
+ @testing.resolve_artifact_names
+ def test_batch_interaction(self):
+ """Test that batching groups same-structured statements
+ with primary keys present together.
+
+ """
+ class T(Base):
+ pass
+ mapper(T, t)
+ sess = Session()
+ sess.add_all([
+ T(data='t1'),
+ T(data='t2'),
+ T(id=3, data='t3'),
+ T(id=4, data='t4'),
+ T(id=5, data='t5'),
+ T(id=6, data=func.lower('t6')),
+ T(id=7, data='t7'),
+ T(id=8, data='t8'),
+ T(id=9, data='t9', def_='def2'),
+ T(id=10, data='t10', def_='def3'),
+ T(id=11, data='t11'),
+ ])
+ self.assert_sql_execution(
+ testing.db,
+ sess.flush,
+ CompiledSQL(
+ "INSERT INTO t (data) VALUES (:data)",
+ {'data': 't1'}
+ ),
+ CompiledSQL(
+ "INSERT INTO t (data) VALUES (:data)",
+ {'data': 't2'}
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, :data)",
+ [{'data': 't3', 'id': 3},
+ {'data': 't4', 'id': 4},
+ {'data': 't5', 'id': 5}]
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, lower(:lower_1))",
+ {'lower_1': 't6', 'id': 6}
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, :data)",
+ [{'data': 't7', 'id': 7}, {'data': 't8', 'id': 8}]
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data, def_) VALUES (:id, :data, :def_)",
+ [{'data': 't9', 'id': 9, 'def_':'def2'},
+ {'data': 't10', 'id': 10, 'def_':'def3'}]
+ ),
+ CompiledSQL(
+ "INSERT INTO t (id, data) VALUES (:id, :data)",
+ {'data': 't11', 'id': 11}
+ ),
+ )