):
base_mapper = mapper.base_mapper
- cached_connections = _cached_connection_dict(base_mapper)
-
if session_transaction.session.connection_callable:
raise NotImplementedError(
"connection_callable / per-instance sharding "
_emit_insert_statements(
base_mapper,
None,
- cached_connections,
super_mapper,
table,
records,
):
base_mapper = mapper.base_mapper
- cached_connections = _cached_connection_dict(base_mapper)
-
search_keys = mapper._primary_key_propkeys
if mapper._version_id_prop:
search_keys = {mapper._version_id_prop.key}.union(search_keys)
_emit_update_statements(
base_mapper,
None,
- cached_connections,
super_mapper,
table,
records,
states_to_update = []
states_to_insert = []
- cached_connections = _cached_connection_dict(base_mapper)
for (
state,
_emit_update_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
update,
_emit_insert_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
insert,
specifies post_update.
"""
- cached_connections = _cached_connection_dict(base_mapper)
states_to_update = list(
_organize_states_for_post_update(base_mapper, states, uowtransaction)
_emit_post_update_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
update,
"""
- cached_connections = _cached_connection_dict(base_mapper)
-
states_to_delete = list(
_organize_states_for_delete(base_mapper, states, uowtransaction)
)
_emit_delete_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
delete,
def _emit_update_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
update,
and mapper.version_id_col in mapper._cols_by_table[table]
)
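+ # the compiled cache is now passed per execution via execution
+ # options, replacing the per-connection _cached_connection_dict()
+ # wrappers removed in this patch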
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
def update_stmt():
clauses = BooleanClauseList._construct_raw(operators.and_)
has_all_defaults,
has_all_pks,
) in records:
- c = connection.execute(statement.values(value_params), params)
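+ # _execute_20() is the Connection's internal 2.0-style execution
+ # path, which accepts execution options for a single statement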
+ c = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
+ )
if bookkeeping:
_postfetch(
mapper,
has_all_defaults,
has_all_pks,
) in records:
- c = cached_connections[connection].execute(
- statement, params
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
)
# TODO: why with bookkeeping=False?
assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].execute(
- statement, multiparams
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
)
rows += c.rowcount
def _emit_insert_statements(
base_mapper,
uowtransaction,
- cached_connections,
mapper,
table,
insert,
cached_stmt = base_mapper._memo(("insert", table), table.insert)
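+ # same pattern as _emit_update_statements: pass the mapper-level
+ # compiled cache as a per-execution option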
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
for (
(connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
records,
records = list(records)
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].execute(statement, multiparams)
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
+
if bookkeeping:
for (
(
if do_executemany:
multiparams = [rec[2] for rec in records]
- c = cached_connections[connection].execute(
- statement, multiparams
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
)
+
if bookkeeping:
for (
(
has_all_defaults,
) in records:
if value_params:
- result = connection.execute(
- statement.values(value_params), params
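+ # SQL expressions are present in the values; render them into
+ # the statement itself and execute single-row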
+ result = connection._execute_20(
+ statement.values(value_params),
+ params,
+ execution_options=execution_options,
)
else:
- result = cached_connections[connection].execute(
- statement, params
+ result = connection._execute_20(
+ statement,
+ params,
+ execution_options=execution_options,
)
primary_key = result.inserted_primary_key
def _emit_post_update_statements(
- base_mapper, uowtransaction, cached_connections, mapper, table, update
+ base_mapper, uowtransaction, mapper, table, update
):
"""Emit UPDATE statements corresponding to value lists collected
by _collect_post_update_commands()."""
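+ # per-execution compiled cache option, as in the other _emit_*
+ # functions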
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
needs_version_id = (
mapper.version_id_col is not None
and mapper.version_id_col in mapper._cols_by_table[table]
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, mapper_rec, connection, params in records:
- c = cached_connections[connection].execute(statement, params)
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
+
_postfetch_post_update(
mapper_rec,
uowtransaction,
assert_singlerow and len(multiparams) == 1
)
- c = cached_connections[connection].execute(statement, multiparams)
+ c = connection._execute_20(
+ statement, multiparams, execution_options=execution_options
+ )
rows += c.rowcount
for state, state_dict, mapper_rec, connection, params in records:
def _emit_delete_statements(
- base_mapper, uowtransaction, cached_connections, mapper, table, delete
+ base_mapper, uowtransaction, mapper, table, delete
):
"""Emit DELETE statements corresponding to value lists collected
by _collect_delete_commands()."""
for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
del_objects = [params for params, connection in recs]
- connection = cached_connections[connection]
-
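+ # here too, the mapper-level compiled cache replaces the
+ # cache-enabled wrapper previously applied to the connection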
+ execution_options = {"compiled_cache": base_mapper._compiled_cache}
expected = len(del_objects)
rows_matched = -1
only_warn = False
# execute deletes individually so that versioned
# rows can be verified
for params in del_objects:
- c = connection.execute(statement, params)
+ c = connection._execute_20(
+ statement, params, execution_options=execution_options
+ )
rows_matched += c.rowcount
else:
util.warn(
"- versioning cannot be verified."
% connection.dialect.dialect_description
)
- connection.execute(statement, del_objects)
+ connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
else:
- c = connection.execute(statement, del_objects)
+ c = connection._execute_20(
+ statement, del_objects, execution_options=execution_options
+ )
if not need_version_id:
only_warn = True
yield state, state.dict, mapper, connection
-def _cached_connection_dict(base_mapper):
- # dictionary of connection->connection_with_cache_options.
- return util.PopulateDict(
- lambda conn: conn.execution_options(
- compiled_cache=base_mapper._compiled_cache
- )
- )
-
-
def _sort_states(mapper, states):
pending = set(states)
persistent = set(s for s in pending if s.key is not None)
from sqlalchemy import FetchedValue
from sqlalchemy import ForeignKey
from sqlalchemy import func
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import JSON
from sqlalchemy import literal
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing.assertsql import AllOf
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
s.commit()
eq_(s.query(cast(JSONThing.data, String)).scalar(), "null")
eq_(s.query(cast(JSONThing.data_null, String)).scalar(), None)
+
+
+class EnsureCacheTest(fixtures.FutureEngineMixin, UOWTest):
+ def test_ensure_cache(self):
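+ """A flush should use the mapper-level compiled cache, passed as
+ a per-execution option, while leaving any compiled_cache option
+ set on the Connection itself in place."""
+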
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ cache = {}
+ eq_(len(inspect(User)._compiled_cache), 0)
+
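+ # each flush below should compile one new statement into the
+ # mapper's cache; the user-supplied cache object must remain the
+ # connection-level option throughout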
+ with testing.db.connect().execution_options(
+ compiled_cache=cache
+ ) as conn:
+ s = Session(conn)
+ u1 = User(name="adf")
+ s.add(u1)
+ s.flush()
+
+ is_(conn._execution_options["compiled_cache"], cache)
+ eq_(len(inspect(User)._compiled_cache), 1)
+
+ u1.name = "newname"
+ s.flush()
+
+ is_(conn._execution_options["compiled_cache"], cache)
+ eq_(len(inspect(User)._compiled_cache), 2)
+
+ s.delete(u1)
+ s.flush()
+
+ is_(conn._execution_options["compiled_cache"], cache)
+ eq_(len(inspect(User)._compiled_cache), 3)