use execute_20 to preserve compiled cache
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Thu, 1 Oct 2020 22:39:42 +0000 (18:39 -0400)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Fri, 2 Oct 2020 12:41:05 +0000 (08:41 -0400)
Fixes: #5420
Change-Id: I3e5a255207da752b7b7cc9b8f41ad5e2ccd0b447

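Previously, each flush helper in persistence.py built a _cached_connection_dict(): a PopulateDict mapping each Connection to a variant carrying compiled_cache=base_mapper._compiled_cache via execution_options(). On the 1.4 "future" Connection, execution_options() applies to the connection in place rather than returning a branched copy, so a compiled_cache the application had set on its own connection was overwritten by the mapper's cache. This commit removes the helper and instead passes the mapper's cache per statement, through the execution_options argument of the internal Connection._execute_20(), leaving the connection's own options untouched. A minimal sketch of the resulting pattern against the public 1.4 API follows; the engine URL, statement, and the caller_cache / mapper_cache names are illustrative and not part of the commit.

    # A minimal sketch, assuming the public SQLAlchemy 1.4 "future" API; the
    # ORM internals call Connection._execute_20() with the same keyword.
    # mapper_cache stands in for base_mapper._compiled_cache.
    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://", future=True)
    caller_cache = {}  # a compiled_cache the application set on its connection
    mapper_cache = {}  # stands in for base_mapper._compiled_cache

    with engine.connect().execution_options(
        compiled_cache=caller_cache
    ) as conn:
        # Passing execution_options per statement scopes the cache override
        # to this one execute() call ...
        conn.execute(
            text("SELECT 1"),
            execution_options={"compiled_cache": mapper_cache},
        )
        # ... so the connection-level option still points at the caller's dict.
        assert conn.get_execution_options()["compiled_cache"] is caller_cache
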
lib/sqlalchemy/orm/persistence.py
test/orm/test_unitofworkv2.py

diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index fa126a279b75198aeb88830fa8f171b85ddbeabc..022f6611f7b165abf05db6f1ebc8aa86f4a4b2d5 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -52,8 +52,6 @@ def _bulk_insert(
 ):
     base_mapper = mapper.base_mapper
 
-    cached_connections = _cached_connection_dict(base_mapper)
-
     if session_transaction.session.connection_callable:
         raise NotImplementedError(
             "connection_callable / per-instance sharding "
@@ -105,7 +103,6 @@ def _bulk_insert(
         _emit_insert_statements(
             base_mapper,
             None,
-            cached_connections,
             super_mapper,
             table,
             records,
@@ -127,8 +124,6 @@ def _bulk_update(
 ):
     base_mapper = mapper.base_mapper
 
-    cached_connections = _cached_connection_dict(base_mapper)
-
     search_keys = mapper._primary_key_propkeys
     if mapper._version_id_prop:
         search_keys = {mapper._version_id_prop.key}.union(search_keys)
@@ -183,7 +178,6 @@ def _bulk_update(
         _emit_update_statements(
             base_mapper,
             None,
-            cached_connections,
             super_mapper,
             table,
             records,
@@ -210,7 +204,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
 
     states_to_update = []
     states_to_insert = []
-    cached_connections = _cached_connection_dict(base_mapper)
 
     for (
         state,
@@ -240,7 +233,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
         _emit_update_statements(
             base_mapper,
             uowtransaction,
-            cached_connections,
             mapper,
             table,
             update,
@@ -249,7 +241,6 @@ def save_obj(base_mapper, states, uowtransaction, single=False):
         _emit_insert_statements(
             base_mapper,
             uowtransaction,
-            cached_connections,
             mapper,
             table,
             insert,
@@ -282,7 +273,6 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
     specifies post_update.
 
     """
-    cached_connections = _cached_connection_dict(base_mapper)
 
     states_to_update = list(
         _organize_states_for_post_update(base_mapper, states, uowtransaction)
@@ -315,7 +305,6 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols):
         _emit_post_update_statements(
             base_mapper,
             uowtransaction,
-            cached_connections,
             mapper,
             table,
             update,
@@ -330,8 +319,6 @@ def delete_obj(base_mapper, states, uowtransaction):
 
     """
 
-    cached_connections = _cached_connection_dict(base_mapper)
-
     states_to_delete = list(
         _organize_states_for_delete(base_mapper, states, uowtransaction)
     )
@@ -352,7 +339,6 @@ def delete_obj(base_mapper, states, uowtransaction):
         _emit_delete_statements(
             base_mapper,
             uowtransaction,
-            cached_connections,
             mapper,
             table,
             delete,
@@ -856,7 +842,6 @@ def _collect_delete_commands(
 def _emit_update_statements(
     base_mapper,
     uowtransaction,
-    cached_connections,
     mapper,
     table,
     update,
@@ -870,6 +855,8 @@ def _emit_update_statements(
         and mapper.version_id_col in mapper._cols_by_table[table]
     )
 
+    execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
     def update_stmt():
         clauses = BooleanClauseList._construct_raw(operators.and_)
 
@@ -948,7 +935,11 @@ def _emit_update_statements(
                 has_all_defaults,
                 has_all_pks,
             ) in records:
-                c = connection.execute(statement.values(value_params), params)
+                c = connection._execute_20(
+                    statement.values(value_params),
+                    params,
+                    execution_options=execution_options,
+                )
                 if bookkeeping:
                     _postfetch(
                         mapper,
@@ -977,8 +968,8 @@ def _emit_update_statements(
                     has_all_defaults,
                     has_all_pks,
                 ) in records:
-                    c = cached_connections[connection].execute(
-                        statement, params
+                    c = connection._execute_20(
+                        statement, params, execution_options=execution_options
                     )
 
                     # TODO: why with bookkeeping=False?
@@ -1003,8 +994,8 @@ def _emit_update_statements(
                     assert_singlerow and len(multiparams) == 1
                 )
 
-                c = cached_connections[connection].execute(
-                    statement, multiparams
+                c = connection._execute_20(
+                    statement, multiparams, execution_options=execution_options
                 )
 
                 rows += c.rowcount
@@ -1054,7 +1045,6 @@ def _emit_update_statements(
 def _emit_insert_statements(
     base_mapper,
     uowtransaction,
-    cached_connections,
     mapper,
     table,
     insert,
@@ -1065,6 +1055,8 @@ def _emit_insert_statements(
 
     cached_stmt = base_mapper._memo(("insert", table), table.insert)
 
+    execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
     for (
         (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
         records,
@@ -1098,7 +1090,10 @@ def _emit_insert_statements(
             records = list(records)
             multiparams = [rec[2] for rec in records]
 
-            c = cached_connections[connection].execute(statement, multiparams)
+            c = connection._execute_20(
+                statement, multiparams, execution_options=execution_options
+            )
+
             if bookkeeping:
                 for (
                     (
@@ -1154,9 +1149,10 @@ def _emit_insert_statements(
             if do_executemany:
                 multiparams = [rec[2] for rec in records]
 
-                c = cached_connections[connection].execute(
-                    statement, multiparams
+                c = connection._execute_20(
+                    statement, multiparams, execution_options=execution_options
                 )
+
                 if bookkeeping:
                     for (
                         (
@@ -1213,12 +1209,16 @@ def _emit_insert_statements(
                     has_all_defaults,
                 ) in records:
                     if value_params:
-                        result = connection.execute(
-                            statement.values(value_params), params
+                        result = connection._execute_20(
+                            statement.values(value_params),
+                            params,
+                            execution_options=execution_options,
                         )
                     else:
-                        result = cached_connections[connection].execute(
-                            statement, params
+                        result = connection._execute_20(
+                            statement,
+                            params,
+                            execution_options=execution_options,
                         )
 
                     primary_key = result.inserted_primary_key
@@ -1253,11 +1253,13 @@ def _emit_insert_statements(
 
 
 def _emit_post_update_statements(
-    base_mapper, uowtransaction, cached_connections, mapper, table, update
+    base_mapper, uowtransaction, mapper, table, update
 ):
     """Emit UPDATE statements corresponding to value lists collected
     by _collect_post_update_commands()."""
 
+    execution_options = {"compiled_cache": base_mapper._compiled_cache}
+
     needs_version_id = (
         mapper.version_id_col is not None
         and mapper.version_id_col in mapper._cols_by_table[table]
@@ -1316,7 +1318,11 @@ def _emit_post_update_statements(
         if not allow_multirow:
             check_rowcount = assert_singlerow
             for state, state_dict, mapper_rec, connection, params in records:
-                c = cached_connections[connection].execute(statement, params)
+
+                c = connection._execute_20(
+                    statement, params, execution_options=execution_options
+                )
+
                 _postfetch_post_update(
                     mapper_rec,
                     uowtransaction,
@@ -1337,7 +1343,9 @@ def _emit_post_update_statements(
                 assert_singlerow and len(multiparams) == 1
             )
 
-            c = cached_connections[connection].execute(statement, multiparams)
+            c = connection._execute_20(
+                statement, multiparams, execution_options=execution_options
+            )
 
             rows += c.rowcount
             for state, state_dict, mapper_rec, connection, params in records:
@@ -1368,7 +1376,7 @@ def _emit_post_update_statements(
 
 
 def _emit_delete_statements(
-    base_mapper, uowtransaction, cached_connections, mapper, table, delete
+    base_mapper, uowtransaction, mapper, table, delete
 ):
     """Emit DELETE statements corresponding to value lists collected
     by _collect_delete_commands()."""
@@ -1400,8 +1408,7 @@ def _emit_delete_statements(
     for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
         del_objects = [params for params, connection in recs]
 
-        connection = cached_connections[connection]
-
+        execution_options = {"compiled_cache": base_mapper._compiled_cache}
         expected = len(del_objects)
         rows_matched = -1
         only_warn = False
@@ -1415,7 +1422,10 @@ def _emit_delete_statements(
                 # execute deletes individually so that versioned
                 # rows can be verified
                 for params in del_objects:
-                    c = connection.execute(statement, params)
+
+                    c = connection._execute_20(
+                        statement, params, execution_options=execution_options
+                    )
                     rows_matched += c.rowcount
             else:
                 util.warn(
@@ -1423,9 +1433,13 @@ def _emit_delete_statements(
                     "- versioning cannot be verified."
                     % connection.dialect.dialect_description
                 )
-                connection.execute(statement, del_objects)
+                connection._execute_20(
+                    statement, del_objects, execution_options=execution_options
+                )
         else:
-            c = connection.execute(statement, del_objects)
+            c = connection._execute_20(
+                statement, del_objects, execution_options=execution_options
+            )
 
             if not need_version_id:
                 only_warn = True
@@ -1702,15 +1716,6 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         yield state, state.dict, mapper, connection
 
 
-def _cached_connection_dict(base_mapper):
-    # dictionary of connection->connection_with_cache_options.
-    return util.PopulateDict(
-        lambda conn: conn.execution_options(
-            compiled_cache=base_mapper._compiled_cache
-        )
-    )
-
-
 def _sort_states(mapper, states):
     pending = set(states)
     persistent = set(s for s in pending if s.key is not None)
diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py
index e5d9a2f7a52518b5cdbc55ece281d20f6173f8b5..ed320db10426b89baed301b11a575981616633c6 100644
--- a/test/orm/test_unitofworkv2.py
+++ b/test/orm/test_unitofworkv2.py
@@ -4,6 +4,7 @@ from sqlalchemy import exc
 from sqlalchemy import FetchedValue
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import JSON
 from sqlalchemy import literal
@@ -25,6 +26,7 @@ from sqlalchemy.testing import config
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import Conditional
@@ -3066,3 +3068,36 @@ class NullEvaluatingTest(fixtures.MappedTest, testing.AssertsExecutionResults):
         s.commit()
         eq_(s.query(cast(JSONThing.data, String)).scalar(), "null")
         eq_(s.query(cast(JSONThing.data_null, String)).scalar(), None)
+
+
+class EnsureCacheTest(fixtures.FutureEngineMixin, UOWTest):
+    def test_ensure_cache(self):
+        users, User = self.tables.users, self.classes.User
+
+        mapper(User, users)
+
+        cache = {}
+        eq_(len(inspect(User)._compiled_cache), 0)
+
+        with testing.db.connect().execution_options(
+            compiled_cache=cache
+        ) as conn:
+            s = Session(conn)
+            u1 = User(name="adf")
+            s.add(u1)
+            s.flush()
+
+            is_(conn._execution_options["compiled_cache"], cache)
+            eq_(len(inspect(User)._compiled_cache), 1)
+
+            u1.name = "newname"
+            s.flush()
+
+            is_(conn._execution_options["compiled_cache"], cache)
+            eq_(len(inspect(User)._compiled_cache), 2)
+
+            s.delete(u1)
+            s.flush()
+
+            is_(conn._execution_options["compiled_cache"], cache)
+            eq_(len(inspect(User)._compiled_cache), 3)
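The new EnsureCacheTest exercises the behavior end to end: with a caller-supplied compiled_cache set on a future connection, each flush (INSERT, UPDATE, then DELETE) still populates the mapper's own _compiled_cache, one statement at a time, while the connection-level option is left pointing at the caller's dictionary throughout.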