]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] The "required" flag is set to
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Aug 2012 19:11:53 +0000 (15:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Aug 2012 19:11:53 +0000 (15:11 -0400)
True by default, if not passed explicitly,
on bindparam() if the "value" or "callable"
parameters are not passed.
This will cause statement execution to check
for the parameter being present in the final
collection of bound parameters, rather than
implicitly assigning None. [ticket:2556]

CHANGES
lib/sqlalchemy/sql/expression.py
test/engine/test_execute.py
test/orm/test_query.py
test/sql/test_compiler.py
test/sql/test_functions.py
test/sql/test_query.py

diff --git a/CHANGES b/CHANGES
index ee75177563d74bad5d248c90923bf09232975ad3..1430ad0b7b26482f628200b5a004d5f668c26ac5 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -328,6 +328,15 @@ underneath "0.7.xx".
     docs for "Registering New Dialects".
     [ticket:2462]
 
+  - [feature] The "required" flag is set to
+    True by default, if not passed explicitly,
+    on bindparam() if the "value" or "callable"
+    parameters are not passed.
+    This will cause statement execution to check
+    for the parameter being present in the final
+    collection of bound parameters, rather than
+    implicitly assigning None. [ticket:2556]
+
   - [bug] The names of the columns on the
     .c. attribute of a select().apply_labels()
     is now based on <tablename>_<colkey> instead
index b7b965ea9c4046b8d757056a7cbe9ddfebc25699..6b184d1ca5f0521aa8082f439bd0e017996d4ddb 100644 (file)
@@ -54,6 +54,7 @@ __all__ = [
     'tuple_', 'type_coerce', 'union', 'union_all', 'update', ]
 
 PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
+NO_ARG = util.symbol('NO_ARG')
 
 def nullsfirst(column):
     """Return a NULLS FIRST ``ORDER BY`` clause element.
@@ -990,7 +991,7 @@ def table(name, *columns):
     """
     return TableClause(name, *columns)
 
-def bindparam(key, value=None, type_=None, unique=False, required=False,
+def bindparam(key, value=NO_ARG, type_=None, unique=False, required=NO_ARG,
                         quote=None, callable_=None):
     """Create a bind parameter clause with the given key.
 
@@ -1007,6 +1008,14 @@ def bindparam(key, value=None, type_=None, unique=False, required=False,
           overridden by the dictionary of parameters sent to statement
           compilation/execution.
 
+          Defaults to ``None``, however if neither ``value`` nor
+          ``callable`` are passed explicitly, the ``required`` flag will be set to
+          ``True`` which has the effect of requiring a value be present
+          when the statement is actually executed.
+
+          .. versionchanged:: 0.8 The ``required`` flag is set to ``True``
+             automatically if ``value`` or ``callable`` is not passed.
+
         :param callable\_:
           A callable function that takes the place of "value".  The function
           will be called at statement execution time to determine the
@@ -1026,7 +1035,14 @@ def bindparam(key, value=None, type_=None, unique=False, required=False,
           :class:`.ClauseElement`.
 
         :param required:
-          a value is required at execution time.
+          If ``True``, a value is required at execution time.  If not passed,
+          is set to ``True`` or ``False`` based on whether or not
+          one of ``value`` or ``callable`` were passed..
+
+          .. versionchanged:: 0.8 If the ``required`` flag is not specified,
+             it will be set automatically to ``True`` or ``False`` depending
+             on whether or not the ``value`` or ``callable`` parameters
+             were specified.
 
         :param quote:
           True if this parameter name requires quoting and is not
@@ -1037,6 +1053,10 @@ def bindparam(key, value=None, type_=None, unique=False, required=False,
     if isinstance(key, ColumnClause):
         type_ = key.type
         key = key.name
+    if required is NO_ARG:
+        required = (value is NO_ARG and callable_ is None)
+    if value is NO_ARG:
+        value = None
     return BindParameter(key, value, type_=type_,
                             callable_=callable_,
                             unique=unique, required=required,
@@ -1703,6 +1723,7 @@ class ClauseElement(Visitable):
         def visit_bindparam(bind):
             if bind.key in kwargs:
                 bind.value = kwargs[bind.key]
+                bind.required = False
             if unique:
                 bind._convert_to_unique()
         return cloned_traverse(self, {}, {'bindparam': visit_bindparam})
index 1067600df42b6501f8e98066d47ed55d0084f607..900a3c8eebc7905a5fbd58b62c792f3dafc1bd84 100644 (file)
@@ -19,7 +19,7 @@ from sqlalchemy.engine.base import Connection, Engine
 from test.lib import fixtures
 import StringIO
 
-users, metadata = None, None
+users, metadata, users_autoinc = None, None, None
 class ExecuteTest(fixtures.TestBase):
     @classmethod
     def setup_class(cls):
@@ -315,11 +315,9 @@ class ExecuteTest(fixtures.TestBase):
     def test_empty_insert(self):
         """test that execute() interprets [] as a list with no params"""
 
-        result = \
-            testing.db.execute(users_autoinc.insert().
-                        values(user_name=bindparam('name')), [])
-        eq_(testing.db.execute(users_autoinc.select()).fetchall(), [(1,
-            None)])
+        testing.db.execute(users_autoinc.insert().
+                    values(user_name=bindparam('name', None)), [])
+        eq_(testing.db.execute(users_autoinc.select()).fetchall(), [(1, None)])
 
     @testing.requires.ad_hoc_engines
     def test_engine_level_options(self):
index a2f1ff1dc301f232f2f5d496258828b34aa117d0..04b62f8c9cacde90ff0ee7eeca56cbf4a2d7e9d1 100644 (file)
@@ -947,7 +947,8 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
 
         session = create_session()
 
-        q = session.query(User.id).filter(User.id==bindparam('foo')).params(foo=7).subquery()
+        q = session.query(User.id).filter(User.id == bindparam('foo')).\
+                            params(foo=7).subquery()
 
         q = session.query(User).filter(User.id.in_(q))
 
@@ -957,7 +958,8 @@ class ExpressionTest(QueryTest, AssertsCompiledSQL):
         User, Address = self.classes.User, self.classes.Address
 
         session = create_session()
-        s = session.query(User.id).join(User.addresses).group_by(User.id).having(func.count(Address.id) > 2)
+        s = session.query(User.id).join(User.addresses).group_by(User.id).\
+                                    having(func.count(Address.id) > 2)
         eq_(
             session.query(User).filter(User.id.in_(s)).all(),
             [User(id=8)]
index 55b583071f569e1dfe800da7d0955fd79baab4fe..40d29f2220799cf018ced3cc02f17341bbead03d 100644 (file)
@@ -1,5 +1,15 @@
 #! coding:utf-8
 
+"""
+compiler tests.
+
+These tests are among the very first that were written when SQLAlchemy
+began in 2005.  As a result the testing style here is very dense;
+it's an ongoing job to break these into much smaller tests with correct pep8
+styling and coherent test organization.
+
+"""
+
 from test.lib.testing import eq_, is_, assert_raises, assert_raises_message
 import datetime, re, operator, decimal
 from sqlalchemy import *
@@ -1446,21 +1456,24 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         # test Text embedded within select_from(), using binds
         generate_series = text(
                             "generate_series(:x, :y, :z) as s(a)",
-                            bindparams=[bindparam('x'), bindparam('y'), bindparam('z')]
+                            bindparams=[bindparam('x', None),
+                                bindparam('y', None), bindparam('z', None)]
                         )
 
-        s =select([
+        s = select([
                     (func.current_date() + literal_column("s.a")).label("dates")
                 ]).select_from(generate_series)
         self.assert_compile(
                     s,
-                    "SELECT CURRENT_DATE + s.a AS dates FROM generate_series(:x, :y, :z) as s(a)",
+                    "SELECT CURRENT_DATE + s.a AS dates FROM "
+                                        "generate_series(:x, :y, :z) as s(a)",
                     checkparams={'y': None, 'x': None, 'z': None}
                 )
 
         self.assert_compile(
                     s.params(x=5, y=6, z=7),
-                    "SELECT CURRENT_DATE + s.a AS dates FROM generate_series(:x, :y, :z) as s(a)",
+                    "SELECT CURRENT_DATE + s.a AS dates FROM "
+                                        "generate_series(:x, :y, :z) as s(a)",
                     checkparams={'y': 6, 'x': 5, 'z': 7}
                 )
 
@@ -1879,7 +1892,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "UNION (SELECT foo, bar FROM bat INTERSECT SELECT foo, bar FROM bat)"
         )
 
-    @testing.uses_deprecated()
     def test_binds(self):
         for (
              stmt,
@@ -1947,13 +1959,15 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                  {'myid_1':5, 'myid_2': 6}, {'myid_1':5, 'myid_2':6}, [5,6]
              ),
              (
-                bindparam('test', type_=String) + text("'hi'"),
+                bindparam('test', type_=String, required=False) + text("'hi'"),
                 ":test || 'hi'",
                 "? || 'hi'",
                 {'test':None}, [None],
                 {}, {'test':None}, [None]
              ),
              (
+                # testing select.params() here - bindparam() objects
+                # must get required flag set to False
                  select([table1], or_(table1.c.myid==bindparam('myid'),
                                     table2.c.otherid==bindparam('myotherid'))).\
                                         params({'myid':8, 'myotherid':7}),
index f0fcd4b7277d272da585462270a5a371af510789..8e5c6bc5813f68b82332a96754b8ec989a21aeb7 100644 (file)
@@ -268,7 +268,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_functions_with_cols(self):
         users = table('users', column('id'), column('name'), column('fullname'))
         calculate = select([column('q'), column('z'), column('r')],
-            from_obj=[func.calculate(bindparam('x'), bindparam('y'))])
+            from_obj=[func.calculate(bindparam('x', None), bindparam('y', None))])
 
         self.assert_compile(select([users], users.c.id > calculate.c.z),
             "SELECT users.id, users.name, users.fullname "
index e79bf32e3f6697cee8661d32d9b6d1138291f07c..670fb2c64f62de8289e5564b386d417dfa8e05d5 100644 (file)
@@ -1,4 +1,4 @@
-from test.lib.testing import eq_, assert_raises_message, assert_raises
+from test.lib.testing import eq_, assert_raises_message, assert_raises, is_
 import datetime
 from sqlalchemy import *
 from sqlalchemy import exc, sql, util
@@ -1216,6 +1216,67 @@ class QueryTest(fixtures.TestBase):
         r = s.execute().fetchall()
         assert len(r) == 1
 
+class RequiredBindTest(fixtures.TablesTest):
+    run_create_tables = None
+    run_deletes = None
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('foo', metadata,
+                Column('id', Integer, primary_key=True),
+                Column('data', String(50)),
+                Column('x', Integer)
+            )
+
+    def _assert_raises(self, stmt, params):
+        assert_raises_message(
+            exc.StatementError,
+            "A value is required for bind parameter 'x'",
+            testing.db.execute, stmt, **params)
+
+        assert_raises_message(
+            exc.StatementError,
+            "A value is required for bind parameter 'x'",
+            testing.db.execute, stmt, params)
+
+    def test_insert(self):
+        stmt = self.tables.foo.insert().values(x=bindparam('x'),
+                                    data=bindparam('data'))
+        self._assert_raises(
+            stmt, {'data': 'data'}
+        )
+
+    def test_select_where(self):
+        stmt = select([self.tables.foo]).\
+                    where(self.tables.foo.c.data == bindparam('data')).\
+                    where(self.tables.foo.c.x == bindparam('x'))
+        self._assert_raises(
+            stmt, {'data': 'data'}
+        )
+
+    def test_select_columns(self):
+        stmt = select([bindparam('data'), bindparam('x')])
+        self._assert_raises(
+            stmt, {'data': 'data'}
+        )
+
+    def test_text(self):
+        stmt = text("select * from foo where x=:x and data=:data1")
+        self._assert_raises(
+            stmt, {'data1': 'data'}
+        )
+
+    def test_required_flag(self):
+        is_(bindparam('foo').required, True)
+        is_(bindparam('foo', required=False).required, False)
+        is_(bindparam('foo', 'bar').required, False)
+        is_(bindparam('foo', 'bar', required=True).required, True)
+
+        c = lambda: None
+        is_(bindparam('foo', callable_=c, required=True).required, True)
+        is_(bindparam('foo', callable_=c).required, False)
+        is_(bindparam('foo', callable_=c, required=False).required, False)
+
 class TableInsertTest(fixtures.TablesTest):
     """test for consistent insert behavior across dialects
     regarding the inline=True flag, lower-case 't' tables.