]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added label truncation for bind param names which was lost in the previous related...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Mar 2007 01:39:58 +0000 (01:39 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Mar 2007 01:39:58 +0000 (01:39 +0000)
added more tests plus test for column targeting with text() clause.

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py
test/sql/labels.py
test/sql/query.py

index 37b6366a9f24ad3fa400b930c89543e62043025a..050e605ebf6b78a223b6f6ed71fa8ceccb8f34d4 100644 (file)
@@ -99,6 +99,10 @@ class ANSICompiler(sql.Compiled):
 
         # a dictionary of bind parameter keys to _BindParamClause instances.
         self.binds = {}
+        
+        # a dictionary of _BindParamClause instances to "compiled" names that are
+        # actually present in the generated SQL
+        self.bind_names = {}
 
         # a dictionary which stores the string representation for every ClauseElement
         # processed by this compiler.
@@ -216,14 +220,16 @@ class ANSICompiler(sql.Compiled):
         bindparams.update(params)
         d = sql.ClauseParameters(self.dialect, self.positiontup)
         for b in self.binds.values():
-            d.set_parameter(b, b.value)
+            name = self.bind_names.get(b, b.key)
+            d.set_parameter(b, b.value, name)
 
         for key, value in bindparams.iteritems():
             try:
                 b = self.binds[key]
             except KeyError:
                 continue
-            d.set_parameter(b, value)
+            name = self.bind_names.get(b, b.key)
+            d.set_parameter(b, value, name)
 
         return d
 
@@ -358,8 +364,11 @@ class ANSICompiler(sql.Compiled):
         return binary.operator
 
     def visit_bindparam(self, bindparam):
+        # apply truncation to the ultimate generated name
+
         if bindparam.shortname != bindparam.key:
             self.binds.setdefault(bindparam.shortname, bindparam)
+
         if bindparam.unique:
             count = 1
             key = bindparam.key
@@ -367,20 +376,29 @@ class ANSICompiler(sql.Compiled):
             # redefine the generated name of the bind param in the case
             # that we have multiple conflicting bind parameters.
             while self.binds.setdefault(key, bindparam) is not bindparam:
-                # ensure the name doesn't expand the length of the string
-                # in case we're at the edge of max identifier length
                 tag = "_%d" % count
-                key = bindparam.key[0 : len(bindparam.key) - len(tag)] + tag
+                key = bindparam.key + tag
                 count += 1
             bindparam.key = key
-            self.strings[bindparam] = self.bindparam_string(key)
+            self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
         else:
             existing = self.binds.get(bindparam.key)
             if existing is not None and existing.unique:
                 raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
-            self.strings[bindparam] = self.bindparam_string(bindparam.key)
+            self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
             self.binds[bindparam.key] = bindparam
+    
+    def _truncate_bindparam(self, bindparam):
+        if bindparam in self.bind_names:
+            return self.bind_names[bindparam]
             
+        bind_name = bindparam.key
+        if len(bind_name) >= self.dialect.max_identifier_length():
+            bind_name = bind_name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:]
+            # add to bind_names for translation
+            self.bind_names[bindparam] = bind_name
+        return bind_name
+        
     def bindparam_string(self, name):
         return self.bindtemplate % name
 
@@ -614,7 +632,7 @@ class ANSICompiler(sql.Compiled):
                 self.binds[p.key] = p
                 if p.shortname is not None:
                     self.binds[p.shortname] = p
-                return self.bindparam_string(p.key)
+                return self.bindparam_string(self._truncate_bindparam(p))
             else:
                 self.inline_params.add(col)
                 self.traverse(p)
@@ -648,7 +666,7 @@ class ANSICompiler(sql.Compiled):
             if isinstance(p, sql._BindParamClause):
                 self.binds[p.key] = p
                 self.binds[p.shortname] = p
-                return self.bindparam_string(p.key)
+                return self.bindparam_string(self._truncate_bindparam(p))
             else:
                 self.traverse(p)
                 self.inline_params.add(col)
index 849dfe1d128773e6bcd9e87bf1c794ebd8d34ead..be43bb21b5f44015706ab08137dda2b0ba75bb90 100644 (file)
@@ -457,9 +457,9 @@ class ClauseParameters(dict):
         self.binds = {}
         self.positional = positional or []
 
-    def set_parameter(self, bindparam, value):
-        self[bindparam.key] = value
-        self.binds[bindparam.key] = bindparam
+    def set_parameter(self, bindparam, value, name):
+        self[name] = value
+        self.binds[name] = bindparam
 
     def get_original(self, key):
         """Return the given parameter as it was originally placed in
index 0b3957619815c970de2a845627310483fc7b6a90..0302fee7845ac9f3949d82308c4bf158edbcb4c4 100644 (file)
@@ -30,5 +30,15 @@ class LongLabelsTest(testbase.PersistTest):
             (4, "data4"),
         ]
     
+    def test_colbinds(self):
+        r = table1.select(table1.c.this_is_the_primary_key_column == 4).execute()
+        assert r.fetchall() == [(4, "data4")]
+
+        r = table1.select(or_(
+            table1.c.this_is_the_primary_key_column == 4,
+            table1.c.this_is_the_primary_key_column == 2
+        )).execute()
+        assert r.fetchall() == [(2, "data2"), (4, "data4")]
+        
 if __name__ == '__main__':
     testbase.main()
\ No newline at end of file
index 22474186757428a33a2fb26e86937474ae622343..3c3e2334c085e157e07d18674ed8e471851ccdf6 100644 (file)
@@ -216,6 +216,10 @@ class QueryTest(PersistTest):
         self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2)
         self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack')
 
+        r = text("select * from query_users where user_id=2", engine=testbase.db).execute().fetchone()
+        self.assert_(r.user_id == r['user_id'] == r[self.users.c.user_id] == 2)
+        self.assert_(r.user_name == r['user_name'] == r[self.users.c.user_name] == 'jack')
+        
     def test_keys(self):
         self.users.insert().execute(user_id=1, user_name='foo')
         r = self.users.select().execute().fetchone()