From: Mike Bayer Date: Thu, 15 Oct 2009 18:41:02 +0000 (+0000) Subject: - an executemany() now requires that all bound parameter X-Git-Tag: rel_0_6beta1~254 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c5571ab19a155f0c11381d65edc07c16902f3fad;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - an executemany() now requires that all bound parameter sets require that all keys are present which are present in the first bound parameter set. The structure and behavior of an insert/update statement is very much determined by the first parameter set, including which defaults are going to fire off, and a minimum of guesswork is performed with all the rest so that performance is not impacted. For this reason defaults would otherwise silently "fail" for missing parameters, so this is now guarded against. [ticket:1566] --- diff --git a/CHANGES b/CHANGES index 6b3c6ba195..e02f56954d 100644 --- a/CHANGES +++ b/CHANGES @@ -86,6 +86,17 @@ CHANGES - the autoincrement flag on column now indicates the column which should be linked to cursor.lastrowid, if that method is used. See the API docs for details. + + - an executemany() now requires that all bound parameter + sets require that all keys are present which are + present in the first bound parameter set. The structure + and behavior of an insert/update statement is very much + determined by the first parameter set, including which + defaults are going to fire off, and a minimum of + guesswork is performed with all the rest so that performance + is not impacted. For this reason defaults would otherwise + silently "fail" for missing parameters, so this is now guarded + against. [ticket:1566] - returning() support is native to insert(), update(), delete(). Implementations of varying levels of diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 12ab605e46..ad728da9c6 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -232,7 +232,7 @@ class DefaultExecutionContext(base.ExecutionContext): self.compiled_parameters = [compiled.construct_params()] self.executemany = False else: - self.compiled_parameters = [compiled.construct_params(m) for m in parameters] + self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for grp,m in enumerate(parameters)] self.executemany = len(parameters) > 1 self.cursor = self.create_cursor() @@ -508,11 +508,22 @@ class DefaultExecutionContext(base.ExecutionContext): if self.executemany: if len(self.compiled.prefetch): - params = self.compiled_parameters - for param in params: + scalar_defaults = {} + + # pre-determine scalar Python-side defaults + # to avoid many calls of get_insert_default()/get_update_default() + for c in self.compiled.prefetch: + if self.isinsert and c.default and c.default.is_scalar: + scalar_defaults[c] = c.default.arg + elif self.isupdate and c.onupdate and c.onupdate.is_scalar: + scalar_defaults[c] = c.onupdate.arg + + for param in self.compiled_parameters: self.current_parameters = param for c in self.compiled.prefetch: - if self.isinsert: + if c in scalar_defaults: + val = scalar_defaults[c] + elif self.isinsert: val = self.get_insert_default(c) else: val = self.get_update_default(c) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 7965070d1e..9798fc23a6 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -1094,6 +1094,10 @@ class ColumnDefault(DefaultGenerator): @util.memoized_property def is_clause_element(self): return isinstance(self.arg, expression.ClauseElement) + + @util.memoized_property + def is_scalar(self): + return not self.is_callable and not self.is_clause_element and not self.is_sequence def _maybe_wrap_callable(self, fn): """Backward compat: Wrap callables that don't accept a context.""" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 61c6c214f1..b204f42b14 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -231,7 +231,7 @@ class SQLCompiler(engine.Compiled): def is_subquery(self): return len(self.stack) > 1 - def construct_params(self, params=None): + def construct_params(self, params=None, _group_number=None): """return a dictionary of bind parameter keys and values""" if params: @@ -242,7 +242,12 @@ class SQLCompiler(engine.Compiled): pd[name] = params[paramname] break else: - if util.callable(bindparam.value): + if bindparam.required: + if _group_number: + raise exc.InvalidRequestError("A value is required for bind parameter %r, in parameter group %d" % (bindparam.key, _group_number)) + else: + raise exc.InvalidRequestError("A value is required for bind parameter %r" % bindparam.key) + elif util.callable(bindparam.value): pd[name] = bindparam.value() else: pd[name] = bindparam.value @@ -751,8 +756,8 @@ class SQLCompiler(engine.Compiled): return text - def _create_crud_bind_param(self, col, value): - bindparam = sql.bindparam(col.key, value, type_=col.type) + def _create_crud_bind_param(self, col, value, required=False): + bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) @@ -770,21 +775,23 @@ class SQLCompiler(engine.Compiled): self.postfetch = [] self.prefetch = [] self.returning = [] - + # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: return [ - (c, self._create_crud_bind_param(c, None)) + (c, self._create_crud_bind_param(c, None, required=True)) for c in stmt.table.columns ] + required = object() + # if we have statement parameters - set defaults in the # compiled params if self.column_keys is None: parameters = {} else: - parameters = dict((sql._column_as_key(key), None) + parameters = dict((sql._column_as_key(key), required) for key in self.column_keys) if stmt.parameters is not None: @@ -808,7 +815,7 @@ class SQLCompiler(engine.Compiled): if c.key in parameters: value = parameters[c.key] if sql._is_literal(value): - value = self._create_crud_bind_param(c, value) + value = self._create_crud_bind_param(c, value, required=value is required) else: self.postfetch.append(c) value = self.process(value.self_group()) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 0ece67e20f..0a703ad36a 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -743,7 +743,7 @@ def table(name, *columns): """ return TableClause(name, *columns) -def bindparam(key, value=None, shortname=None, type_=None, unique=False): +def bindparam(key, value=None, shortname=None, type_=None, unique=False, required=False): """Create a bind parameter clause with the given key. value @@ -762,11 +762,14 @@ def bindparam(key, value=None, shortname=None, type_=None, unique=False): underlying ``key`` modified to a uniquely generated name. mostly useful with value-based bind params. + required + A value is required at execution time. + """ if isinstance(key, ColumnClause): - return _BindParamClause(key.name, value, type_=key.type, unique=unique, shortname=shortname) + return _BindParamClause(key.name, value, type_=key.type, unique=unique, shortname=shortname, required=required) else: - return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname) + return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname, required=required) def outparam(key, type_=None): """Create an 'OUT' parameter for usage in functions (stored procedures), for @@ -2071,7 +2074,7 @@ class _BindParamClause(ColumnElement): __visit_name__ = 'bindparam' quote = None - def __init__(self, key, value, type_=None, unique=False, isoutparam=False, shortname=None): + def __init__(self, key, value, type_=None, unique=False, isoutparam=False, shortname=None, required=False): """Construct a _BindParamClause. key @@ -2100,7 +2103,10 @@ class _BindParamClause(ColumnElement): modified if another ``_BindParamClause`` of the same name already has been located within the containing ``ClauseElement``. - + + required + a value is required at execution time. + isoutparam if True, the parameter should be treated like a stored procedure "OUT" parameter. @@ -2115,7 +2121,8 @@ class _BindParamClause(ColumnElement): self.value = value self.isoutparam = isoutparam self.shortname = shortname - + self.required = required + if type_ is None: self.type = sqltypes.type_map.get(type(value), sqltypes.NullType)() elif isinstance(type_, type): diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index baed19f885..04809b48ae 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -4,7 +4,7 @@ from sqlalchemy import Sequence, Column, func from sqlalchemy.sql import select, text import sqlalchemy as sa from sqlalchemy.test import testing, engines -from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean +from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean, exc from sqlalchemy.test.schema import Table from sqlalchemy.test.testing import eq_ from test.sql import _base @@ -300,7 +300,16 @@ class DefaultTest(testing.TestBase): 12, today, 'py'), (53, 'imthedefault', f, ts, ts, ctexec, True, False, 12, today, 'py')]) - + + def test_missing_many_param(self): + assert_raises_message(exc.InvalidRequestError, + "A value is required for bind parameter 'col7', in parameter group 1", + t.insert().execute, + {'col4':7, 'col7':12, 'col8':19}, + {'col4':7, 'col8':19}, + {'col4':7, 'col7':12, 'col8':19}, + ) + def test_insert_values(self): t.insert(values={'col3':50}).execute() l = t.select().execute() @@ -356,7 +365,7 @@ class DefaultTest(testing.TestBase): l = l.first() eq_(55, l['col3']) - + class PKDefaultTest(_base.TablesTest): __requires__ = ('subqueries',) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 470a694fb9..fe11c62bfc 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -43,12 +43,23 @@ class QueryTest(TestBase): assert users.count().scalar() == 1 def test_insert_heterogeneous_params(self): - users.insert().execute( + """test that executemany parameters are asserted to match the parameter set of the first.""" + + assert_raises_message(exc.InvalidRequestError, + "A value is required for bind parameter 'user_name', in parameter group 2", + users.insert().execute, {'user_id':7, 'user_name':'jack'}, {'user_id':8, 'user_name':'ed'}, {'user_id':9} ) - assert users.select().execute().fetchall() == [(7, 'jack'), (8, 'ed'), (9, None)] + + # this succeeds however. We aren't yet doing + # a length check on all subsequent parameters. + users.insert().execute( + {'user_id':7}, + {'user_id':8, 'user_name':'ed'}, + {'user_id':9} + ) def test_update(self): users.insert().execute(user_id = 7, user_name = 'jack')