]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- port the _collect_insert_commands optimizations from ticket_3100
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 22:33:42 +0000 (18:33 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Aug 2014 22:33:42 +0000 (18:33 -0400)
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py

index 06ec2bf144c64ed50ddc4a117d36808824cefa35..fc15769cd48d8c8fa0159ea798bf67f09cb0dd28 100644 (file)
@@ -1892,6 +1892,41 @@ class Mapper(InspectionAttr):
 
     """
 
+    @_memoized_configured_property
+    def _col_to_propkey(self):
+        return dict(
+            (
+                table,
+                [
+                    (col, self._columntoproperty[col].key)
+                    for col in columns
+                ]
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
+    @_memoized_configured_property
+    def _pk_keys_by_table(self):
+        return dict(
+            (
+                table,
+                frozenset([col.key for col in pks])
+            )
+            for table, pks in self._pks_by_table.items()
+        )
+
+    @_memoized_configured_property
+    def _server_default_cols(self):
+        return dict(
+            (
+                table,
+                frozenset([
+                    col for col in columns
+                    if col.server_default is not None])
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
     @property
     def selectable(self):
         """The :func:`.select` construct this :class:`.Mapper` selects from
index 9d39c39b07fba2baea0029d2180f57d6f24d56e4..b7283e48e2abb60ecf8f0fbee492ad2b8c6538f8 100644 (file)
@@ -18,7 +18,7 @@ import operator
 from itertools import groupby
 from .. import sql, util, exc as sa_exc, schema
 from . import attributes, sync, exc as orm_exc, evaluator
-from .base import _state_mapper, state_str, _attr_as_key
+from .base import state_str, _attr_as_key
 from ..sql import expression
 from . import loading
 
@@ -141,6 +141,7 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
             states):
 
         has_identity = bool(state.key)
+
         instance_key = state.key or mapper._identity_key_from_state(state)
 
         row_switch = None
@@ -183,12 +184,12 @@ def _organize_states_for_save(base_mapper, states, uowtransaction):
         if not has_identity and not row_switch:
             states_to_insert.append(
                 (state, dict_, mapper, connection,
-                 has_identity, instance_key, row_switch)
+                 has_identity, row_switch)
             )
         else:
             states_to_update.append(
                 (state, dict_, mapper, connection,
-                 has_identity, instance_key, row_switch)
+                 has_identity, row_switch)
             )
 
     return states_to_insert, states_to_update
@@ -237,43 +238,41 @@ def _collect_insert_commands(base_mapper, uowtransaction, table,
     """
     insert = []
     for state, state_dict, mapper, connection, has_identity, \
-            instance_key, row_switch in states_to_insert:
+            row_switch in states_to_insert:
+
         if table not in mapper._pks_by_table:
             continue
 
-        pks = mapper._pks_by_table[table]
-
         params = {}
         value_params = {}
-
-        has_all_pks = True
-        has_all_defaults = True
-        has_version_id_generator = mapper.version_id_generator is not False \
-            and mapper.version_id_col is not None
-        for col in mapper._cols_by_table[table]:
-            if has_version_id_generator and col is mapper.version_id_col:
-                val = mapper.version_id_generator(None)
-                params[col.key] = val
+        for col, propkey in mapper._col_to_propkey[table]:
+            if propkey in state_dict:
+                value = state_dict[propkey]
+                if isinstance(value, sql.ClauseElement):
+                    value_params[col.key] = value
+                elif value is not None or (
+                        not col.primary_key and
+                        not col.server_default and
+                        not col.default):
+                    params[col.key] = value
             else:
-                # pull straight from the dict for
-                # pending objects
-                prop = mapper._columntoproperty[col]
-                value = state_dict.get(prop.key, None)
+                if not col.server_default \
+                        and not col.default and not col.primary_key:
+                    params[col.key] = None
 
-                if value is None:
-                    if col in pks:
-                        has_all_pks = False
-                    elif col.default is None and \
-                            col.server_default is None:
-                        params[col.key] = value
-                    elif col.server_default is not None and \
-                            mapper.base_mapper.eager_defaults:
-                        has_all_defaults = False
+        has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
 
-                elif isinstance(value, sql.ClauseElement):
-                    value_params[col] = value
-                else:
-                    params[col.key] = value
+        if base_mapper.eager_defaults:
+            has_all_defaults = mapper._server_default_cols[table].\
+                issubset(params)
+        else:
+            has_all_defaults = True
+
+        if mapper.version_id_generator is not False \
+                and mapper.version_id_col is not None and \
+                mapper.version_id_col in mapper._cols_by_table[table]:
+            params[mapper.version_id_col.key] = \
+                mapper.version_id_generator(None)
 
         insert.append((state, state_dict, params, mapper,
                        connection, value_params, has_all_pks,
@@ -296,7 +295,7 @@ def _collect_update_commands(base_mapper, uowtransaction,
 
     update = []
     for state, state_dict, mapper, connection, has_identity, \
-            instance_key, row_switch in states_to_update:
+            row_switch in states_to_update:
         if table not in mapper._pks_by_table:
             continue
 
@@ -571,7 +570,7 @@ def _emit_insert_statements(base_mapper, uowtransaction,
     for (connection, pkeys, hasvalue, has_all_pks, has_all_defaults), \
         records in groupby(insert,
                            lambda rec: (rec[4],
-                                        list(rec[2].keys()),
+                                        tuple(sorted(rec[2].keys())),
                                         bool(rec[5]),
                                         rec[6], rec[7])
                            ):
@@ -762,7 +761,7 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction,
 
     """
     for state, state_dict, mapper, connection, has_identity, \
-            instance_key, row_switch in states_to_insert + \
+            row_switch in states_to_insert + \
             states_to_update:
 
         if mapper._readonly_props:
@@ -864,7 +863,7 @@ def _connections_for_states(base_mapper, uowtransaction, states):
         if connection_callable:
             connection = connection_callable(base_mapper, state.obj())
 
-        mapper = _state_mapper(state)
+        mapper = state.manager.mapper
 
         yield state, state.dict, mapper, connection