From ccbcbda43e74a1d09d50aa2f8212b3cb9adafd23 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 28 Mar 2007 01:39:58 +0000 Subject: [PATCH] added label truncation for bind param names which was lost in the previous related commit. added more tests plus test for column targeting with text() clause. --- lib/sqlalchemy/ansisql.py | 36 +++++++++++++++++++++++++++--------- lib/sqlalchemy/sql.py | 6 +++--- test/sql/labels.py | 10 ++++++++++ test/sql/query.py | 4 ++++ 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 37b6366a9f..050e605ebf 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -99,6 +99,10 @@ class ANSICompiler(sql.Compiled): # a dictionary of bind parameter keys to _BindParamClause instances. self.binds = {} + + # a dictionary of _BindParamClause instances to "compiled" names that are + # actually present in the generated SQL + self.bind_names = {} # a dictionary which stores the string representation for every ClauseElement # processed by this compiler. @@ -216,14 +220,16 @@ class ANSICompiler(sql.Compiled): bindparams.update(params) d = sql.ClauseParameters(self.dialect, self.positiontup) for b in self.binds.values(): - d.set_parameter(b, b.value) + name = self.bind_names.get(b, b.key) + d.set_parameter(b, b.value, name) for key, value in bindparams.iteritems(): try: b = self.binds[key] except KeyError: continue - d.set_parameter(b, value) + name = self.bind_names.get(b, b.key) + d.set_parameter(b, value, name) return d @@ -358,8 +364,11 @@ class ANSICompiler(sql.Compiled): return binary.operator def visit_bindparam(self, bindparam): + # apply truncation to the ultimate generated name + if bindparam.shortname != bindparam.key: self.binds.setdefault(bindparam.shortname, bindparam) + if bindparam.unique: count = 1 key = bindparam.key @@ -367,20 +376,29 @@ class ANSICompiler(sql.Compiled): # redefine the generated name of the bind param in the case # that we have multiple conflicting bind parameters. while self.binds.setdefault(key, bindparam) is not bindparam: - # ensure the name doesn't expand the length of the string - # in case we're at the edge of max identifier length tag = "_%d" % count - key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag + key = bindparam.key + tag count += 1 bindparam.key = key - self.strings[bindparam] = self.bindparam_string(key) + self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) else: existing = self.binds.get(bindparam.key) if existing is not None and existing.unique: raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) - self.strings[bindparam] = self.bindparam_string(bindparam.key) + self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam)) self.binds[bindparam.key] = bindparam + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + bind_name = bindparam.key + if len(bind_name) >= self.dialect.max_identifier_length(): + bind_name = bind_name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:] + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + return bind_name + def bindparam_string(self, name): return self.bindtemplate % name @@ -614,7 +632,7 @@ class ANSICompiler(sql.Compiled): self.binds[p.key] = p if p.shortname is not None: self.binds[p.shortname] = p - return self.bindparam_string(p.key) + return self.bindparam_string(self._truncate_bindparam(p)) else: self.inline_params.add(col) self.traverse(p) @@ -648,7 +666,7 @@ class ANSICompiler(sql.Compiled): if isinstance(p, sql._BindParamClause): self.binds[p.key] = p self.binds[p.shortname] = p - return self.bindparam_string(p.key) + return self.bindparam_string(self._truncate_bindparam(p)) else: self.traverse(p) self.inline_params.add(col) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 849dfe1d12..be43bb21b5 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -457,9 +457,9 @@ class ClauseParameters(dict): self.binds = {} self.positional = positional or [] - def set_parameter(self, bindparam, value): - self[bindparam.key] = value - self.binds[bindparam.key] = bindparam + def set_parameter(self, bindparam, value, name): + self[name] = value + self.binds[name] = bindparam def get_original(self, key): """Return the given parameter as it was originally placed in diff --git a/test/sql/labels.py b/test/sql/labels.py index 0b39576198..0302fee784 100644 --- a/test/sql/labels.py +++ b/test/sql/labels.py @@ -30,5 +30,15 @@ class LongLabelsTest(testbase.PersistTest): (4, "data4"), ] + def test_colbinds(self): + r = table1.select(table1.c.this_is_the_primary_key_column == 4).execute() + assert r.fetchall() == [(4, "data4")] + + r = table1.select(or_( + table1.c.this_is_the_primary_key_column == 4, + table1.c.this_is_the_primary_key_column == 2 + )).execute() + assert r.fetchall() == [(2, "data2"), (4, "data4")] + if __name__ == '__main__': testbase.main() \ No newline at end of file diff --git a/test/sql/query.py b/test/sql/query.py index 2247418675..3c3e2334c0 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -216,6 +216,10 @@ class QueryTest(PersistTest): self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2) self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack') + r = text("select * from query_users where user_id=2", engine=testbase.db).execute().fetchone() + self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2) + self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack') + def test_keys(self): self.users.insert().execute(user_id=1, user_name='foo') r = self.users.select().execute().fetchone() -- 2.47.2