]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
[ticket:280] statement execution supports using the same BindParam
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Aug 2006 20:12:39 +0000 (20:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 18 Aug 2006 20:12:39 +0000 (20:12 +0000)
object more than once in an expression; simplified handling of positional
parameters.  nice job by Bill Noon figuring out the basic idea.

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index 6ba3a90d5d2196aba81575796a938b0692aa1a1f..5fb83098e5961ba2974a8515698f9c6f298a9166 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -16,6 +16,9 @@ parent isnt available to cascade from.
 to save)
 - improved the check for objects being part of a session when the
 unit of work seeks to flush() them as part of a relationship..
+- [ticket:280] statement execution supports using the same BindParam
+object more than once in an expression; simplified handling of positional
+parameters.  nice job by Bill Noon figuring out the basic idea.
 
 0.2.7
 - quoting facilities set up so that database-specific quoting can be
index 36ae93bc654272922773d368126b89ea08da1926..031c63328409d8d41957e7967248fd2e6c3483c4 100644 (file)
@@ -74,6 +74,7 @@ class ANSICompiler(sql.Compiled):
         self.bindtemplate = ":%s"
         self.paramstyle = dialect.paramstyle
         self.positional = dialect.positional
+        self.positiontup = []
         self.preparer = dialect.preparer()
         
     def after_compile(self):
@@ -84,7 +85,6 @@ class ANSICompiler(sql.Compiled):
         if self.paramstyle=='pyformat':
             self.strings[self.statement] = re.sub(match, lambda m:'%(' + m.group(1) +')s', self.strings[self.statement])
         elif self.positional:
-            self.positiontup = []
             params = re.finditer(match, self.strings[self.statement])
             for p in params:
                 self.positiontup.append(p.group(1))
@@ -128,15 +128,10 @@ class ANSICompiler(sql.Compiled):
             bindparams = {}
         bindparams.update(params)
 
-        d = sql.ClauseParameters(self.dialect)
-        if self.positional:
-            for k in self.positiontup:
-                b = self.binds[k]
-                d.set_parameter(k, b.value, b)
-        else:
-            for b in self.binds.values():
-                d.set_parameter(b.key, b.value, b)
-            
+        d = sql.ClauseParameters(self.dialect, self.positiontup)
+        for b in self.binds.values():
+            d.set_parameter(b.key, b.value, b)
+
         for key, value in bindparams.iteritems():
             try:
                 b = self.binds[key]
@@ -146,20 +141,6 @@ class ANSICompiler(sql.Compiled):
 
         return d
 
-    def get_named_params(self, parameters):
-        """given the results of the get_params method, returns the parameters
-        in dictionary format.  For a named paramstyle, this just returns the
-        same dictionary.  For a positional paramstyle, the given parameters are
-        assumed to be in list format and are converted back to a dictionary.
-        """
-        if self.positional:
-            p = {}
-            for i in range(0, len(self.positiontup)):
-                p[self.positiontup[i]] = parameters[i]
-            return p
-        else:
-            return parameters
-    
     def default_from(self):
         """called when a SELECT statement has no froms, and no FROM clause is to be appended.  
         gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
index f0ffd7797e2f10d4b6712d874d5d39fb631e7cf1..6bef1fabd167cde9e6822a5d42e2254f047fb481 100644 (file)
@@ -98,9 +98,9 @@ class DefaultDialect(base.Dialect):
         if parameters is not None:
            if self.positional:
                 if executemany:
-                    parameters = [p.values() for p in parameters]
+                    parameters = [p.get_raw_list() for p in parameters]
                 else:
-                    parameters = parameters.values()
+                    parameters = parameters.get_raw_list()
            else:
                 if executemany:
                     parameters = [p.get_raw_dict() for p in parameters]
index 18591c24cff314137af3dc71527ae8e847dc9330..2aa1342ca07a3ec142994d747a70d6d864983f02 100644 (file)
@@ -273,12 +273,19 @@ class AbstractDialect(object):
     """represents the behavior of a particular database.  Used by Compiled objects."""
     pass
     
-class ClauseParameters(util.OrderedDict):
-    """represents a dictionary/iterator of bind parameter key names/values.  Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method.  Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects.  The non-converted value can be retrieved via the get_original method.  For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order."""
-    def __init__(self, dialect):
+class ClauseParameters(dict):
+    """represents a dictionary/iterator of bind parameter key names/values.  
+    
+    Tracks the original BindParam objects as well as the keys/position of each
+    parameter, and can return parameters as a dictionary or a list.
+    Will process parameter values according to the TypeEngine objects present in
+    the BindParams.
+    """
+    def __init__(self, dialect, positional=None):
         super(ClauseParameters, self).__init__(self)
         self.dialect=dialect
         self.binds = {}
+        self.positional = positional or []
     def set_parameter(self, key, value, bindparam):
         self[key] = value
         self.binds[key] = bindparam
@@ -290,10 +297,10 @@ class ClauseParameters(util.OrderedDict):
         if self.binds.has_key(key):
             v = self.binds[key].typeprocess(v, self.dialect)
         return v
-    def values(self):
-        return [self[key] for key in self]
     def get_original_dict(self):
         return self.copy()
+    def get_raw_list(self):
+        return [self[key] for key in self.positional]
     def get_raw_dict(self):
         d = {}
         for k in self:
index 2148aae672b67ad6314c2de6187cd1a491034d02..ccb998e9e3f73e006930cae1a66713bb9b5a35f3 100644 (file)
@@ -115,6 +115,16 @@ class QueryTest(PersistTest):
         finally:
             test_table.drop()
 
+    def test_repeated_bindparams(self):
+        """test that a BindParam can be used more than once.  
+        this should be run for dbs with both positional and named paramstyles."""
+        self.users.insert().execute(user_id = 7, user_name = 'jack')
+        self.users.insert().execute(user_id = 8, user_name = 'fred')
+
+        u = bindparam('uid')
+        s = self.users.select(or_(self.users.c.user_name==u, self.users.c.user_name==u))
+        r = s.execute(uid='fred').fetchall()
+        assert len(r) == 1
         
     def testdelete(self):
         self.users.insert().execute(user_id = 7, user_name = 'jack')