From: Mike Bayer Date: Thu, 1 Oct 2020 22:39:42 +0000 (-0400) Subject: use execute_20 to preserve compiled cache X-Git-Tag: rel_1_4_0b1~60^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=1430a30dd2033550699cf5e074cb81058687eb13;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git use execute_20 to preserve compiled cache Fixes: #5420 Change-Id: I3e5a255207da752b7b7cc9b8f41ad5e2ccd0b447 --- diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index fa126a279b..022f6611f7 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -52,8 +52,6 @@ def _bulk_insert( ): base_mapper = mapper.base_mapper - cached_connections = _cached_connection_dict(base_mapper) - if session_transaction.session.connection_callable: raise NotImplementedError( "connection_callable / per-instance sharding " @@ -105,7 +103,6 @@ def _bulk_insert( _emit_insert_statements( base_mapper, None, - cached_connections, super_mapper, table, records, @@ -127,8 +124,6 @@ def _bulk_update( ): base_mapper = mapper.base_mapper - cached_connections = _cached_connection_dict(base_mapper) - search_keys = mapper._primary_key_propkeys if mapper._version_id_prop: search_keys = {mapper._version_id_prop.key}.union(search_keys) @@ -183,7 +178,6 @@ def _bulk_update( _emit_update_statements( base_mapper, None, - cached_connections, super_mapper, table, records, @@ -210,7 +204,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False): states_to_update = [] states_to_insert = [] - cached_connections = _cached_connection_dict(base_mapper) for ( state, @@ -240,7 +233,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False): _emit_update_statements( base_mapper, uowtransaction, - cached_connections, mapper, table, update, @@ -249,7 +241,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False): _emit_insert_statements( base_mapper, uowtransaction, - cached_connections, mapper, table, insert, @@ -282,7 +273,6 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): specifies post_update. """ - cached_connections = _cached_connection_dict(base_mapper) states_to_update = list( _organize_states_for_post_update(base_mapper, states, uowtransaction) @@ -315,7 +305,6 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): _emit_post_update_statements( base_mapper, uowtransaction, - cached_connections, mapper, table, update, @@ -330,8 +319,6 @@ def delete_obj(base_mapper, states, uowtransaction): """ - cached_connections = _cached_connection_dict(base_mapper) - states_to_delete = list( _organize_states_for_delete(base_mapper, states, uowtransaction) ) @@ -352,7 +339,6 @@ def delete_obj(base_mapper, states, uowtransaction): _emit_delete_statements( base_mapper, uowtransaction, - cached_connections, mapper, table, delete, @@ -856,7 +842,6 @@ def _collect_delete_commands( def _emit_update_statements( base_mapper, uowtransaction, - cached_connections, mapper, table, update, @@ -870,6 +855,8 @@ def _emit_update_statements( and mapper.version_id_col in mapper._cols_by_table[table] ) + execution_options = {"compiled_cache": base_mapper._compiled_cache} + def update_stmt(): clauses = BooleanClauseList._construct_raw(operators.and_) @@ -948,7 +935,11 @@ def _emit_update_statements( has_all_defaults, has_all_pks, ) in records: - c = connection.execute(statement.values(value_params), params) + c = connection._execute_20( + statement.values(value_params), + params, + execution_options=execution_options, + ) if bookkeeping: _postfetch( mapper, @@ -977,8 +968,8 @@ def _emit_update_statements( has_all_defaults, has_all_pks, ) in records: - c = cached_connections[connection].execute( - statement, params + c = connection._execute_20( + statement, params, execution_options=execution_options ) # TODO: why with bookkeeping=False? @@ -1003,8 +994,8 @@ def _emit_update_statements( assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].execute( - statement, multiparams + c = connection._execute_20( + statement, multiparams, execution_options=execution_options ) rows += c.rowcount @@ -1054,7 +1045,6 @@ def _emit_update_statements( def _emit_insert_statements( base_mapper, uowtransaction, - cached_connections, mapper, table, insert, @@ -1065,6 +1055,8 @@ def _emit_insert_statements( cached_stmt = base_mapper._memo(("insert", table), table.insert) + execution_options = {"compiled_cache": base_mapper._compiled_cache} + for ( (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), records, @@ -1098,7 +1090,10 @@ def _emit_insert_statements( records = list(records) multiparams = [rec[2] for rec in records] - c = cached_connections[connection].execute(statement, multiparams) + c = connection._execute_20( + statement, multiparams, execution_options=execution_options + ) + if bookkeeping: for ( ( @@ -1154,9 +1149,10 @@ def _emit_insert_statements( if do_executemany: multiparams = [rec[2] for rec in records] - c = cached_connections[connection].execute( - statement, multiparams + c = connection._execute_20( + statement, multiparams, execution_options=execution_options ) + if bookkeeping: for ( ( @@ -1213,12 +1209,16 @@ def _emit_insert_statements( has_all_defaults, ) in records: if value_params: - result = connection.execute( - statement.values(value_params), params + result = connection._execute_20( + statement.values(value_params), + params, + execution_options=execution_options, ) else: - result = cached_connections[connection].execute( - statement, params + result = connection._execute_20( + statement, + params, + execution_options=execution_options, ) primary_key = result.inserted_primary_key @@ -1253,11 +1253,13 @@ def _emit_insert_statements( def _emit_post_update_statements( - base_mapper, uowtransaction, cached_connections, mapper, table, update + base_mapper, uowtransaction, mapper, table, update ): """Emit UPDATE statements corresponding to value lists collected by _collect_post_update_commands().""" + execution_options = {"compiled_cache": base_mapper._compiled_cache} + needs_version_id = ( mapper.version_id_col is not None and mapper.version_id_col in mapper._cols_by_table[table] @@ -1316,7 +1318,11 @@ def _emit_post_update_statements( if not allow_multirow: check_rowcount = assert_singlerow for state, state_dict, mapper_rec, connection, params in records: - c = cached_connections[connection].execute(statement, params) + + c = connection._execute_20( + statement, params, execution_options=execution_options + ) + _postfetch_post_update( mapper_rec, uowtransaction, @@ -1337,7 +1343,9 @@ def _emit_post_update_statements( assert_singlerow and len(multiparams) == 1 ) - c = cached_connections[connection].execute(statement, multiparams) + c = connection._execute_20( + statement, multiparams, execution_options=execution_options + ) rows += c.rowcount for state, state_dict, mapper_rec, connection, params in records: @@ -1368,7 +1376,7 @@ def _emit_post_update_statements( def _emit_delete_statements( - base_mapper, uowtransaction, cached_connections, mapper, table, delete + base_mapper, uowtransaction, mapper, table, delete ): """Emit DELETE statements corresponding to value lists collected by _collect_delete_commands().""" @@ -1400,8 +1408,7 @@ def _emit_delete_statements( for connection, recs in groupby(delete, lambda rec: rec[1]): # connection del_objects = [params for params, connection in recs] - connection = cached_connections[connection] - + execution_options = {"compiled_cache": base_mapper._compiled_cache} expected = len(del_objects) rows_matched = -1 only_warn = False @@ -1415,7 +1422,10 @@ def _emit_delete_statements( # execute deletes individually so that versioned # rows can be verified for params in del_objects: - c = connection.execute(statement, params) + + c = connection._execute_20( + statement, params, execution_options=execution_options + ) rows_matched += c.rowcount else: util.warn( @@ -1423,9 +1433,13 @@ def _emit_delete_statements( "- versioning cannot be verified." % connection.dialect.dialect_description ) - connection.execute(statement, del_objects) + connection._execute_20( + statement, del_objects, execution_options=execution_options + ) else: - c = connection.execute(statement, del_objects) + c = connection._execute_20( + statement, del_objects, execution_options=execution_options + ) if not need_version_id: only_warn = True @@ -1702,15 +1716,6 @@ def _connections_for_states(base_mapper, uowtransaction, states): yield state, state.dict, mapper, connection -def _cached_connection_dict(base_mapper): - # dictionary of connection->connection_with_cache_options. - return util.PopulateDict( - lambda conn: conn.execution_options( - compiled_cache=base_mapper._compiled_cache - ) - ) - - def _sort_states(mapper, states): pending = set(states) persistent = set(s for s in pending if s.key is not None) diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index e5d9a2f7a5..ed320db104 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -4,6 +4,7 @@ from sqlalchemy import exc from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import JSON from sqlalchemy import literal @@ -25,6 +26,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional @@ -3066,3 +3068,36 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults): s.commit() eq_(s.query(cast(JSONThing.data, String)).scalar(), "null") eq_(s.query(cast(JSONThing.data_null, String)).scalar(), None) + + +class EnsureCacheTest(fixtures.FutureEngineMixin, UOWTest): + def test_ensure_cache(self): + users, User = self.tables.users, self.classes.User + + mapper(User, users) + + cache = {} + eq_(len(inspect(User)._compiled_cache), 0) + + with testing.db.connect().execution_options( + compiled_cache=cache + ) as conn: + s = Session(conn) + u1 = User(name="adf") + s.add(u1) + s.flush() + + is_(conn._execution_options["compiled_cache"], cache) + eq_(len(inspect(User)._compiled_cache), 1) + + u1.name = "newname" + s.flush() + + is_(conn._execution_options["compiled_cache"], cache) + eq_(len(inspect(User)._compiled_cache), 2) + + s.delete(u1) + s.flush() + + is_(conn._execution_options["compiled_cache"], cache) + eq_(len(inspect(User)._compiled_cache), 3)