From fc17f7e65933cc5b3329436a9dd5f28a094dcc7a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 18 Aug 2006 20:12:39 +0000 Subject: [PATCH] [ticket:280] statement execution supports using the same BindParam object more than once in an expression; simplified handling of positional parameters. nice job by Bill Noon figuring out the basic idea. --- CHANGES | 3 +++ lib/sqlalchemy/ansisql.py | 29 +++++------------------------ lib/sqlalchemy/engine/default.py | 4 ++-- lib/sqlalchemy/sql.py | 17 ++++++++++++----- test/sql/query.py | 10 ++++++++++ 5 files changed, 32 insertions(+), 31 deletions(-) diff --git a/CHANGES b/CHANGES index 6ba3a90d5d..5fb83098e5 100644 --- a/CHANGES +++ b/CHANGES @@ -16,6 +16,9 @@ parent isnt available to cascade from. to save) - improved the check for objects being part of a session when the unit of work seeks to flush() them as part of a relationship.. +- [ticket:280] statement execution supports using the same BindParam +object more than once in an expression; simplified handling of positional +parameters. nice job by Bill Noon figuring out the basic idea. 0.2.7 - quoting facilities set up so that database-specific quoting can be diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 36ae93bc65..031c633284 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -74,6 +74,7 @@ class ANSICompiler(sql.Compiled): self.bindtemplate = ":%s" self.paramstyle = dialect.paramstyle self.positional = dialect.positional + self.positiontup = [] self.preparer = dialect.preparer() def after_compile(self): @@ -84,7 +85,6 @@ class ANSICompiler(sql.Compiled): if self.paramstyle=='pyformat': self.strings[self.statement] = re.sub(match, lambda m:'%(' + m.group(1) +')s', self.strings[self.statement]) elif self.positional: - self.positiontup = [] params = re.finditer(match, self.strings[self.statement]) for p in params: self.positiontup.append(p.group(1)) @@ -128,15 +128,10 @@ class ANSICompiler(sql.Compiled): bindparams = {} bindparams.update(params) - d = sql.ClauseParameters(self.dialect) - if self.positional: - for k in self.positiontup: - b = self.binds[k] - d.set_parameter(k, b.value, b) - else: - for b in self.binds.values(): - d.set_parameter(b.key, b.value, b) - + d = sql.ClauseParameters(self.dialect, self.positiontup) + for b in self.binds.values(): + d.set_parameter(b.key, b.value, b) + for key, value in bindparams.iteritems(): try: b = self.binds[key] @@ -146,20 +141,6 @@ class ANSICompiler(sql.Compiled): return d - def get_named_params(self, parameters): - """given the results of the get_params method, returns the parameters - in dictionary format. For a named paramstyle, this just returns the - same dictionary. For a positional paramstyle, the given parameters are - assumed to be in list format and are converted back to a dictionary. - """ - if self.positional: - p = {} - for i in range(0, len(self.positiontup)): - p[self.positiontup[i]] = parameters[i] - return p - else: - return parameters - def default_from(self): """called when a SELECT statement has no froms, and no FROM clause is to be appended. gives Oracle a chance to tack on a "FROM DUAL" to the string output. """ diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index f0ffd7797e..6bef1fabd1 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -98,9 +98,9 @@ class DefaultDialect(base.Dialect): if parameters is not None: if self.positional: if executemany: - parameters = [p.values() for p in parameters] + parameters = [p.get_raw_list() for p in parameters] else: - parameters = parameters.values() + parameters = parameters.get_raw_list() else: if executemany: parameters = [p.get_raw_dict() for p in parameters] diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 18591c24cf..2aa1342ca0 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -273,12 +273,19 @@ class AbstractDialect(object): """represents the behavior of a particular database. Used by Compiled objects.""" pass -class ClauseParameters(util.OrderedDict): - """represents a dictionary/iterator of bind parameter key names/values. Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method. Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects. The non-converted value can be retrieved via the get_original method. For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order.""" - def __init__(self, dialect): +class ClauseParameters(dict): + """represents a dictionary/iterator of bind parameter key names/values. + + Tracks the original BindParam objects as well as the keys/position of each + parameter, and can return parameters as a dictionary or a list. + Will process parameter values according to the TypeEngine objects present in + the BindParams. + """ + def __init__(self, dialect, positional=None): super(ClauseParameters, self).__init__(self) self.dialect=dialect self.binds = {} + self.positional = positional or [] def set_parameter(self, key, value, bindparam): self[key] = value self.binds[key] = bindparam @@ -290,10 +297,10 @@ class ClauseParameters(util.OrderedDict): if self.binds.has_key(key): v = self.binds[key].typeprocess(v, self.dialect) return v - def values(self): - return [self[key] for key in self] def get_original_dict(self): return self.copy() + def get_raw_list(self): + return [self[key] for key in self.positional] def get_raw_dict(self): d = {} for k in self: diff --git a/test/sql/query.py b/test/sql/query.py index 2148aae672..ccb998e9e3 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -115,6 +115,16 @@ class QueryTest(PersistTest): finally: test_table.drop() + def test_repeated_bindparams(self): + """test that a BindParam can be used more than once. + this should be run for dbs with both positional and named paramstyles.""" + self.users.insert().execute(user_id = 7, user_name = 'jack') + self.users.insert().execute(user_id = 8, user_name = 'fred') + + u = bindparam('uid') + s = self.users.select(or_(self.users.c.user_name==u, self.users.c.user_name==u)) + r = s.execute(uid='fred').fetchall() + assert len(r) == 1 def testdelete(self): self.users.insert().execute(user_id = 7, user_name = 'jack') -- 2.47.2