from ..sql.compiler import DDLCompiler
from ..sql.compiler import SQLCompiler
from ..sql.elements import quoted_name
-from ..sql.schema import default_is_scalar
from ..util.typing import Final
from ..util.typing import Literal
self.cursor = self.create_cursor()
if self.compiled.insert_prefetch or self.compiled.update_prefetch:
- if self.executemany:
- self._process_executemany_defaults()
- else:
- self._process_executesingle_defaults()
+ self._process_execute_defaults()
processors = compiled._bind_processors
if default.is_sequence:
return self.fire_sequence(default, type_)
elif default.is_callable:
+ # this codepath is not normally used as it's inlined
+ # into _process_execute_defaults
self.current_column = column
return default.arg(self)
elif default.is_clause_element:
return self._exec_default_clause_element(column, default, type_)
else:
+ # this codepath is not normally used as it's inlined
+ # into _process_execute_defaults
return default.arg
def _exec_default_clause_element(self, column, default, type_):
else:
return self._exec_default(column, column.onupdate, column.type)
- def _process_executemany_defaults(self):
+ def _process_execute_defaults(self):
compiled = cast(SQLCompiler, self.compiled)
key_getter = compiled._within_exec_param_key_getter
- scalar_defaults: Dict[Column[Any], Any] = {}
-
- insert_prefetch = compiled.insert_prefetch
- update_prefetch = compiled.update_prefetch
-
# pre-determine scalar Python-side defaults
# to avoid many calls of get_insert_default()/
# get_update_default()
- for c in insert_prefetch:
- if c.default and default_is_scalar(c.default):
- scalar_defaults[c] = c.default.arg
-
- for c in update_prefetch:
- if c.onupdate and default_is_scalar(c.onupdate):
- scalar_defaults[c] = c.onupdate.arg
+ if compiled.insert_prefetch:
+ prefetch_recs = [
+ (
+ c,
+ key_getter(c),
+ (
+ c.default.arg, # type: ignore
+ c.default.is_scalar,
+ c.default.is_callable,
+ )
+ if c.default and c.default.has_arg
+ else (None, None, None),
+ self.get_insert_default,
+ )
+ for c in compiled.insert_prefetch
+ ]
+ elif compiled.update_prefetch:
+ prefetch_recs = [
+ (
+ c,
+ key_getter(c),
+ (
+ c.onupdate.arg, # type: ignore
+ c.onupdate.is_scalar,
+ c.onupdate.is_callable,
+ )
+ if c.onupdate and c.onupdate.has_arg
+ else (None, None, None),
+ self.get_update_default,
+ )
+ for c in compiled.update_prefetch
+ ]
+ else:
+ prefetch_recs = []
for param in self.compiled_parameters:
self.current_parameters = param
- for c in insert_prefetch:
- if c in scalar_defaults:
- val = scalar_defaults[c]
- else:
- val = self.get_insert_default(c)
- if val is not None:
- param[key_getter(c)] = val
- for c in update_prefetch:
- if c in scalar_defaults:
- val = scalar_defaults[c]
- else:
- val = self.get_update_default(c)
- if val is not None:
- param[key_getter(c)] = val
-
- del self.current_parameters
-
- def _process_executesingle_defaults(self):
- compiled = cast(SQLCompiler, self.compiled)
- key_getter = compiled._within_exec_param_key_getter
- self.current_parameters = (
- compiled_parameters
- ) = self.compiled_parameters[0]
-
- for c in compiled.insert_prefetch:
- if c.default and default_is_scalar(c.default):
- val = c.default.arg
- else:
- val = self.get_insert_default(c)
-
- if val is not None:
- compiled_parameters[key_getter(c)] = val
-
- for c in compiled.update_prefetch:
- val = self.get_update_default(c)
+ for (
+ c,
+ param_key,
+ (arg, is_scalar, is_callable),
+ fallback,
+ ) in prefetch_recs:
+ if is_scalar:
+ param[param_key] = arg
+ elif is_callable:
+ self.current_column = c
+ param[param_key] = arg(self) # type: ignore
+ else:
+ val = fallback(c)
+ if val is not None:
+ param[param_key] = val
- if val is not None:
- compiled_parameters[key_getter(c)] = val
del self.current_parameters