]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [bug] Fixed issue where the "required" exception
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jan 2012 19:20:25 +0000 (14:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 28 Jan 2012 19:20:25 +0000 (14:20 -0500)
would not be raised for bindparam() with required=True,
if the statement were given no parameters at all.
[ticket:2381]

CHANGES
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

diff --git a/CHANGES b/CHANGES
index e1c9969df73dd173ec45c675e444fc6169f2553b..cf25308cae3e7e9fce68af92657ffa558f89314c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -78,6 +78,11 @@ CHANGES
     constructs to sqlalchemy.sql namespace, though
     not part of __all__ as of yet.
 
+  - [bug] Fixed issue where the "required" exception
+    would not be raised for bindparam() with required=True,
+    if the statement were given no parameters at all.
+    [ticket:2381]
+
 - engine
   - [bug] Added __reduce__ to StatementError, 
     DBAPIError, column errors so that exceptions 
index c63ae1aea29d5144aab0bc494864132d1523147b..93e2473d94bedf8abced57b1f019ca1777ad4759 100644 (file)
@@ -289,7 +289,7 @@ class SQLCompiler(engine.Compiled):
     def sql_compiler(self):
         return self
 
-    def construct_params(self, params=None, _group_number=None):
+    def construct_params(self, params=None, _group_number=None, _check=True):
         """return a dictionary of bind parameter keys and values"""
 
         if params:
@@ -299,29 +299,40 @@ class SQLCompiler(engine.Compiled):
                     pd[name] = params[bindparam.key]
                 elif name in params:
                     pd[name] = params[name]
-                elif bindparam.required:
+                elif _check and bindparam.required:
                     if _group_number:
                         raise exc.InvalidRequestError(
-                                "A value is required for bind parameter %r, "
-                                "in parameter group %d" % 
-                                (bindparam.key, _group_number))
+                            "A value is required for bind parameter %r, "
+                            "in parameter group %d" % 
+                            (bindparam.key, _group_number))
                     else:
                         raise exc.InvalidRequestError(
-                                "A value is required for bind parameter %r" 
-                                % bindparam.key)
+                            "A value is required for bind parameter %r" 
+                            % bindparam.key)
                 else:
                     pd[name] = bindparam.effective_value
             return pd
         else:
             pd = {}
             for bindparam in self.bind_names:
+                if _check and bindparam.required:
+                    if _group_number:
+                        raise exc.InvalidRequestError(
+                            "A value is required for bind parameter %r, "
+                            "in parameter group %d" % 
+                            (bindparam.key, _group_number))
+                    else:
+                        raise exc.InvalidRequestError(
+                            "A value is required for bind parameter %r" 
+                            % bindparam.key)
                 pd[self.bind_names[bindparam]] = bindparam.effective_value
             return pd
 
-    params = property(construct_params, doc="""
-        Return the bind params for this compiled object.
-
-    """)
+    @property
+    def params(self):
+        """Return the bind param dictionary embedded into this 
+        compiled object, for those values that are present."""
+        return self.construct_params(_check=False)
 
     def default_from(self):
         """Called when a SELECT statement has no froms, and no FROM clause is
index 9a53dd89ccaab1c6009524acbf65ff5c217145a0..b84a566d56a8da415adcd88978d8c272cf6deb50 100644 (file)
@@ -1918,6 +1918,48 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             {'x':12}
         )
 
+    def test_bind_params_missing(self):
+        assert_raises_message(exc.InvalidRequestError, 
+            r"A value is required for bind parameter 'x'",
+            select([table1]).where(
+                    and_(
+                        table1.c.myid==bindparam("x", required=True), 
+                        table1.c.name==bindparam("y", required=True)
+                    )
+                ).compile().construct_params,
+            params=dict(y=5)
+        )
+
+        assert_raises_message(exc.InvalidRequestError, 
+            r"A value is required for bind parameter 'x'",
+            select([table1]).where(
+                    table1.c.myid==bindparam("x", required=True)
+                ).compile().construct_params
+        )
+
+        assert_raises_message(exc.InvalidRequestError, 
+            r"A value is required for bind parameter 'x', "
+                "in parameter group 2",
+            select([table1]).where(
+                    and_(
+                        table1.c.myid==bindparam("x", required=True), 
+                        table1.c.name==bindparam("y", required=True)
+                    )
+                ).compile().construct_params,
+            params=dict(y=5),
+            _group_number=2
+        )
+
+        assert_raises_message(exc.InvalidRequestError, 
+            r"A value is required for bind parameter 'x', "
+                "in parameter group 2",
+            select([table1]).where(
+                    table1.c.myid==bindparam("x", required=True)
+                ).compile().construct_params,
+            _group_number=2
+        )
+
+
 
     @testing.emits_warning('.*empty sequence.*')
     def test_in(self):