]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
optimize exec defaults a bit
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Apr 2023 01:16:58 +0000 (21:16 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Apr 2023 04:39:03 +0000 (00:39 -0400)
since I am probably using this for the new "sentinel" thing,
clean up this code, reduce codepaths and inline a bit

Change-Id: I9cb312828e3bc23636f6db794b169f1acc4ebae3

lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/schema.py

index 3e4e6fb9ae65d6be9222e83e368a413f51fc2f7d..462473de22c3abb460183d4dbd365ae6f58da6b7 100644 (file)
@@ -62,7 +62,6 @@ from ..sql.base import _NoArg
 from ..sql.compiler import DDLCompiler
 from ..sql.compiler import SQLCompiler
 from ..sql.elements import quoted_name
-from ..sql.schema import default_is_scalar
 from ..util.typing import Final
 from ..util.typing import Literal
 
@@ -1203,10 +1202,7 @@ class DefaultExecutionContext(ExecutionContext):
         self.cursor = self.create_cursor()
 
         if self.compiled.insert_prefetch or self.compiled.update_prefetch:
-            if self.executemany:
-                self._process_executemany_defaults()
-            else:
-                self._process_executesingle_defaults()
+            self._process_execute_defaults()
 
         processors = compiled._bind_processors
 
@@ -1907,11 +1903,15 @@ class DefaultExecutionContext(ExecutionContext):
         if default.is_sequence:
             return self.fire_sequence(default, type_)
         elif default.is_callable:
+            # this codepath is not normally used as it's inlined
+            # into _process_execute_defaults
             self.current_column = column
             return default.arg(self)
         elif default.is_clause_element:
             return self._exec_default_clause_element(column, default, type_)
         else:
+            # this codepath is not normally used as it's inlined
+            # into _process_execute_defaults
             return default.arg
 
     def _exec_default_clause_element(self, column, default, type_):
@@ -2054,68 +2054,68 @@ class DefaultExecutionContext(ExecutionContext):
         else:
             return self._exec_default(column, column.onupdate, column.type)
 
-    def _process_executemany_defaults(self):
+    def _process_execute_defaults(self):
         compiled = cast(SQLCompiler, self.compiled)
 
         key_getter = compiled._within_exec_param_key_getter
 
-        scalar_defaults: Dict[Column[Any], Any] = {}
-
-        insert_prefetch = compiled.insert_prefetch
-        update_prefetch = compiled.update_prefetch
-
         # pre-determine scalar Python-side defaults
         # to avoid many calls of get_insert_default()/
         # get_update_default()
-        for c in insert_prefetch:
-            if c.default and default_is_scalar(c.default):
-                scalar_defaults[c] = c.default.arg
-
-        for c in update_prefetch:
-            if c.onupdate and default_is_scalar(c.onupdate):
-                scalar_defaults[c] = c.onupdate.arg
+        if compiled.insert_prefetch:
+            prefetch_recs = [
+                (
+                    c,
+                    key_getter(c),
+                    (
+                        c.default.arg,  # type: ignore
+                        c.default.is_scalar,
+                        c.default.is_callable,
+                    )
+                    if c.default and c.default.has_arg
+                    else (None, None, None),
+                    self.get_insert_default,
+                )
+                for c in compiled.insert_prefetch
+            ]
+        elif compiled.update_prefetch:
+            prefetch_recs = [
+                (
+                    c,
+                    key_getter(c),
+                    (
+                        c.onupdate.arg,  # type: ignore
+                        c.onupdate.is_scalar,
+                        c.onupdate.is_callable,
+                    )
+                    if c.onupdate and c.onupdate.has_arg
+                    else (None, None, None),
+                    self.get_update_default,
+                )
+                for c in compiled.update_prefetch
+            ]
+        else:
+            prefetch_recs = []
 
         for param in self.compiled_parameters:
             self.current_parameters = param
-            for c in insert_prefetch:
-                if c in scalar_defaults:
-                    val = scalar_defaults[c]
-                else:
-                    val = self.get_insert_default(c)
-                if val is not None:
-                    param[key_getter(c)] = val
-            for c in update_prefetch:
-                if c in scalar_defaults:
-                    val = scalar_defaults[c]
-                else:
-                    val = self.get_update_default(c)
-                if val is not None:
-                    param[key_getter(c)] = val
-
-        del self.current_parameters
-
-    def _process_executesingle_defaults(self):
-        compiled = cast(SQLCompiler, self.compiled)
 
-        key_getter = compiled._within_exec_param_key_getter
-        self.current_parameters = (
-            compiled_parameters
-        ) = self.compiled_parameters[0]
-
-        for c in compiled.insert_prefetch:
-            if c.default and default_is_scalar(c.default):
-                val = c.default.arg
-            else:
-                val = self.get_insert_default(c)
-
-            if val is not None:
-                compiled_parameters[key_getter(c)] = val
-
-        for c in compiled.update_prefetch:
-            val = self.get_update_default(c)
+            for (
+                c,
+                param_key,
+                (arg, is_scalar, is_callable),
+                fallback,
+            ) in prefetch_recs:
+                if is_scalar:
+                    param[param_key] = arg
+                elif is_callable:
+                    self.current_column = c
+                    param[param_key] = arg(self)  # type: ignore
+                else:
+                    val = fallback(c)
+                    if val is not None:
+                        param[param_key] = val
 
-            if val is not None:
-                compiled_parameters[key_getter(c)] = val
         del self.current_parameters
 
 
index b4263137b38182c770a47001a38284f645eec08b..ab56d2552ca7c7151c6d3f556773c76b07fc9784 100644 (file)
@@ -3093,6 +3093,7 @@ class DefaultGenerator(Executable, SchemaItem):
     is_clause_element = False
     is_callable = False
     is_scalar = False
+    has_arg = False
     column: Optional[Column[Any]]
 
     def __init__(self, for_update: bool = False) -> None:
@@ -3234,6 +3235,7 @@ class ScalarElementColumnDefault(ColumnDefault):
     """
 
     is_scalar = True
+    has_arg = True
 
     def __init__(self, arg: Any, for_update: bool = False) -> None:
         self.for_update = for_update
@@ -3256,7 +3258,7 @@ class ColumnElementColumnDefault(ColumnDefault):
     """
 
     is_clause_element = True
-
+    has_arg = True
     arg: _SQLExprDefault
 
     def __init__(
@@ -3294,6 +3296,7 @@ class CallableColumnDefault(ColumnDefault):
 
     is_callable = True
     arg: _CallableColumnDefaultProtocol
+    has_arg = True
 
     def __init__(
         self,