]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- an executemany() now requires that all bound parameter
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Oct 2009 18:41:02 +0000 (18:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Oct 2009 18:41:02 +0000 (18:41 +0000)
sets require that all keys are present which are
present in the first bound parameter set.  The structure
and behavior of an insert/update statement is very much
determined by the first parameter set, including which
defaults are going to fire off, and a minimum of
guesswork is performed with all the rest so that performance
is not impacted.  For this reason defaults would otherwise
silently "fail" for missing parameters, so this is now guarded
against. [ticket:1566]

CHANGES
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/sql/test_defaults.py
test/sql/test_query.py

diff --git a/CHANGES b/CHANGES
index 6b3c6ba195cb72eeed5e87fb27efdf3ecc98ef17..e02f56954d710f415f23652af53e33e8a4f0eabf 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -86,6 +86,17 @@ CHANGES
     - the autoincrement flag on column now indicates the column
       which should be linked to cursor.lastrowid, if that method
       is used.  See the API docs for details.
+    
+    - an executemany() now requires that all bound parameter
+      sets require that all keys are present which are 
+      present in the first bound parameter set.  The structure
+      and behavior of an insert/update statement is very much
+      determined by the first parameter set, including which
+      defaults are going to fire off, and a minimum of 
+      guesswork is performed with all the rest so that performance
+      is not impacted.  For this reason defaults would otherwise 
+      silently "fail" for missing parameters, so this is now guarded 
+      against. [ticket:1566]
       
     - returning() support is native to insert(), update(),
       delete(). Implementations of varying levels of
index 12ab605e4637b0874b4579a8ee9d692ecafcb634..ad728da9c6645b422b081185b7472d05395c6890 100644 (file)
@@ -232,7 +232,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self.compiled_parameters = [compiled.construct_params()]
                 self.executemany = False
             else:
-                self.compiled_parameters = [compiled.construct_params(m) for m in parameters]
+                self.compiled_parameters = [compiled.construct_params(m, _group_number=grp) for grp,m in enumerate(parameters)]
                 self.executemany = len(parameters) > 1
 
             self.cursor = self.create_cursor()
@@ -508,11 +508,22 @@ class DefaultExecutionContext(base.ExecutionContext):
 
         if self.executemany:
             if len(self.compiled.prefetch):
-                params = self.compiled_parameters
-                for param in params:
+                scalar_defaults = {}
+                
+                # pre-determine scalar Python-side defaults
+                # to avoid many calls of get_insert_default()/get_update_default()
+                for c in self.compiled.prefetch:
+                    if self.isinsert and c.default and c.default.is_scalar:
+                        scalar_defaults[c] = c.default.arg
+                    elif self.isupdate and c.onupdate and c.onupdate.is_scalar:
+                        scalar_defaults[c] = c.onupdate.arg
+                        
+                for param in self.compiled_parameters:
                     self.current_parameters = param
                     for c in self.compiled.prefetch:
-                        if self.isinsert:
+                        if c in scalar_defaults:
+                            val = scalar_defaults[c]
+                        elif self.isinsert:
                             val = self.get_insert_default(c)
                         else:
                             val = self.get_update_default(c)
index 7965070d1e332340f7033c2191582eaade3f245c..9798fc23a6d9176f48a3d7d67a5f31dd97b53d70 100644 (file)
@@ -1094,6 +1094,10 @@ class ColumnDefault(DefaultGenerator):
     @util.memoized_property
     def is_clause_element(self):
         return isinstance(self.arg, expression.ClauseElement)
+    
+    @util.memoized_property
+    def is_scalar(self):
+        return not self.is_callable and not self.is_clause_element and not self.is_sequence
         
     def _maybe_wrap_callable(self, fn):
         """Backward compat: Wrap callables that don't accept a context."""
index 61c6c214f1764cf4abcdaa40eb5cd0dcfde5bb6e..b204f42b14580eee24fad6fca7a872217a45e506 100644 (file)
@@ -231,7 +231,7 @@ class SQLCompiler(engine.Compiled):
     def is_subquery(self):
         return len(self.stack) > 1
 
-    def construct_params(self, params=None):
+    def construct_params(self, params=None, _group_number=None):
         """return a dictionary of bind parameter keys and values"""
 
         if params:
@@ -242,7 +242,12 @@ class SQLCompiler(engine.Compiled):
                         pd[name] = params[paramname]
                         break
                 else:
-                    if util.callable(bindparam.value):
+                    if bindparam.required:
+                        if _group_number:
+                            raise exc.InvalidRequestError("A value is required for bind parameter %r, in parameter group %d" % (bindparam.key, _group_number))
+                        else:
+                            raise exc.InvalidRequestError("A value is required for bind parameter %r" % bindparam.key)
+                    elif util.callable(bindparam.value):
                         pd[name] = bindparam.value()
                     else:
                         pd[name] = bindparam.value
@@ -751,8 +756,8 @@ class SQLCompiler(engine.Compiled):
 
         return text
 
-    def _create_crud_bind_param(self, col, value):
-        bindparam = sql.bindparam(col.key, value, type_=col.type)
+    def _create_crud_bind_param(self, col, value, required=False):
+        bindparam = sql.bindparam(col.key, value, type_=col.type, required=required)
         self.binds[col.key] = bindparam
         return self.bindparam_string(self._truncate_bindparam(bindparam))
         
@@ -770,21 +775,23 @@ class SQLCompiler(engine.Compiled):
         self.postfetch = []
         self.prefetch = []
         self.returning = []
-        
+
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
             return [
-                        (c, self._create_crud_bind_param(c, None)) 
+                        (c, self._create_crud_bind_param(c, None, required=True)) 
                         for c in stmt.table.columns
                     ]
 
+        required = object()
+        
         # if we have statement parameters - set defaults in the
         # compiled params
         if self.column_keys is None:
             parameters = {}
         else:
-            parameters = dict((sql._column_as_key(key), None)
+            parameters = dict((sql._column_as_key(key), required)
                               for key in self.column_keys)
 
         if stmt.parameters is not None:
@@ -808,7 +815,7 @@ class SQLCompiler(engine.Compiled):
             if c.key in parameters:
                 value = parameters[c.key]
                 if sql._is_literal(value):
-                    value = self._create_crud_bind_param(c, value)
+                    value = self._create_crud_bind_param(c, value, required=value is required)
                 else:
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
index 0ece67e20f7e02ea51d2ffbc258d876c6b136725..0a703ad36a49bec16d5d08aa77b1e9e55455a16d 100644 (file)
@@ -743,7 +743,7 @@ def table(name, *columns):
     """
     return TableClause(name, *columns)
 
-def bindparam(key, value=None, shortname=None, type_=None, unique=False):
+def bindparam(key, value=None, shortname=None, type_=None, unique=False, required=False):
     """Create a bind parameter clause with the given key.
 
     value
@@ -762,11 +762,14 @@ def bindparam(key, value=None, shortname=None, type_=None, unique=False):
       underlying ``key`` modified to a uniquely generated name.
       mostly useful with value-based bind params.
 
+    required
+      A value is required at execution time.
+      
     """
     if isinstance(key, ColumnClause):
-        return _BindParamClause(key.name, value, type_=key.type, unique=unique, shortname=shortname)
+        return _BindParamClause(key.name, value, type_=key.type, unique=unique, shortname=shortname, required=required)
     else:
-        return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname)
+        return _BindParamClause(key, value, type_=type_, unique=unique, shortname=shortname, required=required)
 
 def outparam(key, type_=None):
     """Create an 'OUT' parameter for usage in functions (stored procedures), for
@@ -2071,7 +2074,7 @@ class _BindParamClause(ColumnElement):
     __visit_name__ = 'bindparam'
     quote = None
 
-    def __init__(self, key, value, type_=None, unique=False, isoutparam=False, shortname=None):
+    def __init__(self, key, value, type_=None, unique=False, isoutparam=False, shortname=None, required=False):
         """Construct a _BindParamClause.
 
         key
@@ -2100,7 +2103,10 @@ class _BindParamClause(ColumnElement):
           modified if another ``_BindParamClause`` of the same name
           already has been located within the containing
           ``ClauseElement``.
-
+        
+        required
+          a value is required at execution time.
+          
         isoutparam
           if True, the parameter should be treated like a stored procedure "OUT"
           parameter.
@@ -2115,7 +2121,8 @@ class _BindParamClause(ColumnElement):
         self.value = value
         self.isoutparam = isoutparam
         self.shortname = shortname
-
+        self.required = required
+        
         if type_ is None:
             self.type = sqltypes.type_map.get(type(value), sqltypes.NullType)()
         elif isinstance(type_, type):
index baed19f88542f4e866c1d1335801a3d1dc8aa051..04809b48aec304ee7fc896178bffb45ff9383d4c 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy import Sequence, Column, func
 from sqlalchemy.sql import select, text
 import sqlalchemy as sa
 from sqlalchemy.test import testing, engines
-from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean
+from sqlalchemy import MetaData, Integer, String, ForeignKey, Boolean, exc
 from sqlalchemy.test.schema import Table
 from sqlalchemy.test.testing import eq_
 from test.sql import _base
@@ -300,7 +300,16 @@ class DefaultTest(testing.TestBase):
               12, today, 'py'),
              (53, 'imthedefault', f, ts, ts, ctexec, True, False,
               12, today, 'py')])
-
+    
+    def test_missing_many_param(self):
+        assert_raises_message(exc.InvalidRequestError, 
+            "A value is required for bind parameter 'col7', in parameter group 1",
+            t.insert().execute,
+            {'col4':7, 'col7':12, 'col8':19},
+            {'col4':7, 'col8':19},
+            {'col4':7, 'col7':12, 'col8':19},
+        )
+        
     def test_insert_values(self):
         t.insert(values={'col3':50}).execute()
         l = t.select().execute()
@@ -356,7 +365,7 @@ class DefaultTest(testing.TestBase):
         l = l.first()
         eq_(55, l['col3'])
 
-
+    
 class PKDefaultTest(_base.TablesTest):
     __requires__ = ('subqueries',)
 
index 470a694fb947a8a206aff6cc429fd1ef7b42dfcb..fe11c62bfc7185f82f08d040e480bd85f27eef43 100644 (file)
@@ -43,12 +43,23 @@ class QueryTest(TestBase):
         assert users.count().scalar() == 1
 
     def test_insert_heterogeneous_params(self):
-        users.insert().execute(
+        """test that executemany parameters are asserted to match the parameter set of the first."""
+        
+        assert_raises_message(exc.InvalidRequestError, 
+            "A value is required for bind parameter 'user_name', in parameter group 2",
+            users.insert().execute,
             {'user_id':7, 'user_name':'jack'},
             {'user_id':8, 'user_name':'ed'},
             {'user_id':9}
         )
-        assert users.select().execute().fetchall() == [(7, 'jack'), (8, 'ed'), (9, None)]
+
+        # this succeeds however.   We aren't yet doing 
+        # a length check on all subsequent parameters.
+        users.insert().execute(
+            {'user_id':7},
+            {'user_id':8, 'user_name':'ed'},
+            {'user_id':9}
+        )
 
     def test_update(self):
         users.insert().execute(user_id = 7, user_name = 'jack')