git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- converted all anonymous labels and aliases to be generated within the compilation...
author: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jul 2007 18:55:05 +0000 (18:55 +0000)
committer: Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jul 2007 18:55:05 +0000 (18:55 +0000)
- also some tweaks to unicode result column names; characters are no longer stripped out of the names, since a name might be composed entirely of non-ASCII characters.  MySQL needs some work here since it's returning, I think, the unicode string's internally-encoded bytes directly within a bytestring.
- need to reduce the number of dictionaries present in ANSICompiler; it's pretty hard to follow at this point.

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql.py
test/sql/labels.py
test/sql/unicode.py

index 7322c50bd910da018280d0a1cc513f9003b8c9ce..0072f4686994e2417fa0d6a13c184fc41be5a54f 100644 (file)
@@ -175,7 +175,7 @@ class ANSICompiler(engine.Compiled):
         # and also knows postfetching will be needed to get the values represented by these
         # parameters.
         self.inline_params = None
-
+        
     def after_compile(self):
         # this re will search for params like :param
         # it has a negative lookbehind for an extra ':' so that it doesnt match
@@ -244,7 +244,7 @@ class ANSICompiler(engine.Compiled):
         bindparams.update(params)
         d = sql.ClauseParameters(self.dialect, self.positiontup)
         for b in self.binds.values():
-            name = self.bind_names.get(b, b.key)
+            name = self.bind_names[b]
             d.set_parameter(b, b.value, name)
 
         for key, value in bindparams.iteritems():
@@ -252,7 +252,7 @@ class ANSICompiler(engine.Compiled):
                 b = self.binds[key]
             except KeyError:
                 continue
-            name = self.bind_names.get(b, b.key)
+            name = self.bind_names[b]
             d.set_parameter(b, value, name)
 
         return d
@@ -294,15 +294,15 @@ class ANSICompiler(engine.Compiled):
             if column.table.oid_column is column:
                 n = self.dialect.oid_column_name(column)
                 if n is not None:
-                    self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False), n)
+                    self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
                 elif len(column.table.primary_key) != 0:
                     pk = list(column.table.primary_key)[0]
                     pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
-                    self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname)
+                    self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
                 else:
                     self.strings[column] = None
             else:
-                self.strings[column] = self.preparer.format_column_with_table(column, column_name=name)
+                self.strings[column] = self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
 
         if len(self.select_stack):
             # if we are within a visit to a Select, set up the "typemap"
@@ -397,7 +397,6 @@ class ANSICompiler(engine.Compiled):
         if bindparam.unique:
             count = 1
             key = bindparam.key
-
             # redefine the generated name of the bind param in the case
             # that we have multiple conflicting bind parameters.
             while self.binds.setdefault(key, bindparam) is not bindparam:
@@ -418,29 +417,44 @@ class ANSICompiler(engine.Compiled):
             return self.bind_names[bindparam]
             
         bind_name = bindparam.key
-        if len(bind_name) > self.dialect.max_identifier_length():
-            bind_name = self._truncated_identifier("bindparam", bind_name)
-            # add to bind_names for translation
-            self.bind_names[bindparam] = bind_name
+        bind_name = self._truncated_identifier("bindparam", bind_name)
+        # add to bind_names for translation
+        self.bind_names[bindparam] = bind_name
+            
         return bind_name
     
     def _truncated_identifier(self, ident_class, name):
         if (ident_class, name) in self.generated_ids:
             return self.generated_ids[(ident_class, name)]
-        if len(name) > self.dialect.max_identifier_length():
+            
+        anonname = self._anonymize(name)
+        if len(anonname) > self.dialect.max_identifier_length():
             counter = self.generated_ids.get(ident_class, 1)
             truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
             self.generated_ids[ident_class] = counter + 1
         else:
-            truncname = name
+            truncname = anonname
         self.generated_ids[(ident_class, name)] = truncname
         return truncname
+    
+    def _anonymize(self, name):
+        def anon(match):
+            (ident, derived) = match.group(1,2)
+            if ('anonymous', ident) in self.generated_ids:
+                return self.generated_ids[('anonymous', ident)]
+            else:
+                anonymous_counter = self.generated_ids.get('anonymous', 1)
+                newname = derived + "_" + str(anonymous_counter)
+                self.generated_ids['anonymous'] = anonymous_counter + 1
+                self.generated_ids[('anonymous', ident)] = newname
+                return newname
+        return re.sub(r'{ANON (\d+) (.*)}', anon, name)
             
     def bindparam_string(self, name):
         return self.bindtemplate % name
 
     def visit_alias(self, alias):
-        self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
+        self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
         self.strings[alias] = self.get_str(alias.original)
 
     def enter_select(self, select):
@@ -1089,8 +1103,8 @@ class ANSIIdentifierPreparer(object):
     def format_label(self, label, name=None):
         return self.__generic_obj_format(label, name or label.name)
 
-    def format_alias(self, alias):
-        return self.__generic_obj_format(alias, alias.name)
+    def format_alias(self, alias, name=None):
+        return self.__generic_obj_format(alias, name or alias.name)
 
     def format_savepoint(self, savepoint):
         return self.__generic_obj_format(savepoint, savepoint)
@@ -1105,25 +1119,25 @@ class ANSIIdentifierPreparer(object):
             result = self.__generic_obj_format(table, table.schema) + "." + result
         return result
 
-    def format_column(self, column, use_table=False, name=None):
+    def format_column(self, column, use_table=False, name=None, table_name=None):
         """Prepare a quoted column name."""
         if name is None:
             name = column.name
         if not getattr(column, 'is_literal', False):
             if use_table:
-                return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name)
+                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.__generic_obj_format(column, name)
             else:
                 return self.__generic_obj_format(column, name)
         else:
             # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
             if use_table:
-                return self.format_table(column.table, use_schema=False) + "." + name
+                return self.format_table(column.table, use_schema=False, name=table_name) + "." + name
             else:
                 return name
 
-    def format_column_with_table(self, column, column_name=None):
+    def format_column_with_table(self, column, column_name=None, table_name=None):
         """Prepare a quoted column name with table name."""
         
-        return self.format_column(column, use_table=True, name=column_name)
+        return self.format_column(column, use_table=True, name=column_name, table_name=table_name)
 
 dialect = ANSIDialect
index 6f4cca2234fee0ed3bd44cd99947466626010b00..008d90b0e7c9587291495e70d4d1729dac4fb097 100644 (file)
@@ -1118,7 +1118,7 @@ class ResultProxy(object):
 
             for i, item in enumerate(metadata):
                 # sqlite possibly prepending table name to colnames so strip
-                colname = item[0].split('.')[-1]
+                colname = self.dialect.decode_result_columnname(item[0].split('.')[-1])
                 if self.context.typemap is not None:
                     type = self.context.typemap.get(colname.lower(), typemap.get(item[1], types.NULLTYPE))
                 else:
@@ -1150,7 +1150,9 @@ class ResultProxy(object):
             elif isinstance(key, basestring) and key.lower() in props:
                 rec = props[key.lower()]
             elif isinstance(key, sql.ColumnElement):
+                print "LABEL ON COLUMN", repr(key.key), "IS", repr(key._label)
                 label = context.column_labels.get(key._label, key.name).lower()
+                print "SO YEAH, NOW WE GOT LABEL", repr(label), "AND PROPS IS", repr(props)
                 if label in props:
                     rec = props[label]
 
index c5e1e76ee3e92fd288a4e9032b401ffd321924a7..d44a7095c0e5002c8e10e745d5e7e96b08fa4b0e 100644 (file)
@@ -38,6 +38,11 @@ class DefaultDialect(base.Dialect):
                 map[obj().get_dbapi_type(self.dialect)] = obj
         self._dbapi_type_map = map
     
+    def decode_result_columnname(self, name):
+        """decode a name found in cursor.description to a unicode object."""
+        
+        return name.decode(self.encoding)
+        
     def dbapi_type_map(self):
         return self._dbapi_type_map
             
index 9a4afb9f5b245187f62a1a1518c2a41f8fd903d3..12070b2b421b3057e91865d8d2814d7a0b096ed1 100644 (file)
@@ -7,7 +7,6 @@
 from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
 from sqlalchemy.orm import mapper, class_mapper, object_mapper
 from sqlalchemy.orm.interfaces import OperationContext
-import random
 
 __all__ = ['Query', 'QueryContext', 'SelectionContext']
 
@@ -217,7 +216,7 @@ class Query(object):
         # alias non-labeled column elements. 
         # TODO: make the generation deterministic
         if isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
-            column = column.label("anon_" + hex(random.randint(0, 65535))[2:])
+            column = column.label(None)
 
         q._entities = q._entities + [column]
         return q
index b1fbc736986b220e1b36ef1e6ebc1a92c065665f..36dd99ae8d4aa33f98f9fa4045f223a19c1e1157 100644 (file)
@@ -11,7 +11,6 @@ from sqlalchemy.orm import mapper, attributes
 from sqlalchemy.orm.interfaces import *
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
-import random
 
 
 class ColumnLoader(LoaderStrategy):
@@ -395,10 +394,6 @@ class LazyLoader(AbstractRelationLoader):
             FindColumnInColumnClause().traverse(expr)
             return len(columns) and columns[0] or None
         
-        def bind_label():
-            # TODO: make this generation deterministic
-            return "lazy_" + hex(random.randint(0, 65535))[2:]
-
         def visit_binary(binary):
             leftcol = find_column_in_expr(binary.left)
             rightcol = find_column_in_expr(binary.right)
@@ -407,7 +402,7 @@ class LazyLoader(AbstractRelationLoader):
             if should_bind(leftcol, rightcol):
                 col = leftcol
                 binary.left = binds.setdefault(leftcol,
-                        sql.bindparam(bind_label(), None, shortname=leftcol.name, type=binary.right.type, unique=True))
+                        sql.bindparam(None, None, shortname=leftcol.name, type=binary.right.type, unique=True))
                 reverse[rightcol] = binds[col]
 
             # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
@@ -415,7 +410,7 @@ class LazyLoader(AbstractRelationLoader):
             if leftcol is not rightcol and should_bind(rightcol, leftcol):
                 col = rightcol
                 binary.right = binds.setdefault(rightcol,
-                        sql.bindparam(bind_label(), None, shortname=rightcol.name, type=binary.left.type, unique=True))
+                        sql.bindparam(None, None, shortname=rightcol.name, type=binary.left.type, unique=True))
                 reverse[leftcol] = binds[col]
 
         lazywhere = primaryjoin
@@ -485,14 +480,13 @@ class EagerLoader(AbstractRelationLoader):
         """
         
         def __init__(self, eagerloader, parentclauses=None):
-            self.id = (parentclauses is not None and (parentclauses.id + "/") or '') + str(eagerloader.parent_property)
             self.parent = eagerloader
             self.target = eagerloader.select_table
-            self.eagertarget = eagerloader.select_table.alias(self._aliashash("/target"))
+            self.eagertarget = eagerloader.select_table.alias(None)
             self.extra_cols = {}
 
             if eagerloader.secondary:
-                self.eagersecondary = eagerloader.secondary.alias(self._aliashash("/secondary"))
+                self.eagersecondary = eagerloader.secondary.alias(None)
                 if parentclauses is not None:
                     aliasizer = sql_util.ClauseAdapter(self.eagertarget).\
                             chain(sql_util.ClauseAdapter(self.eagersecondary)).\
@@ -540,17 +534,11 @@ class EagerLoader(AbstractRelationLoader):
                     select._should_correlate = False
                     select.append_correlation(self.eagertarget)
             aliased_column = sql_util.ClauseAdapter(self.eagertarget).chain(ModifySubquery()).traverse(aliased_column, clone=True)
-            alias = self._aliashash(column.name)
-            aliased_column = aliased_column.label(alias)
+            aliased_column = aliased_column.label(None)
             self._row_decorator.map[column] = alias
             self.extra_cols[column] = aliased_column
             return aliased_column
             
-        def _aliashash(self, extra):
-            """return a deterministic 4 digit hash value for this AliasedClause's id + extra."""
-            # use the first 4 digits of an MD5 hash
-            return "anon_" + util.hash(self.id + extra)[0:4]
-            
         def _create_decorator_row(self):
             class EagerRowAdapter(object):
                 def __init__(self, row):
index 38c1cb13f78b10174c09f6333fbb9d193f3b2208..db7625382b79ac788d8f05a80cd6b26ef45bd3f3 100644 (file)
@@ -26,7 +26,7 @@ are less guaranteed to stay the same in future releases.
 
 from sqlalchemy import util, exceptions, logging
 from sqlalchemy import types as sqltypes
-import string, re, random, sets
+import string, re, sets
 
 __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
            'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
@@ -1736,7 +1736,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
           ``ClauseElement``.
         """
 
-        self.key = key
+        self.key = key or "{ANON %d param}" % id(self)
         self.value = value
         self.shortname = shortname or key
         self.unique = unique
@@ -2301,11 +2301,7 @@ class Alias(FromClause):
         if alias is None:
             if self.original.named_with_column():
                 alias = getattr(self.original, 'name', None)
-            if alias is None:
-                alias = 'anon'
-            elif len(alias) > 15:
-                alias = alias[0:15]
-            alias = alias + "_" + hex(random.randint(0, 65535))[2:]
+            alias = '{ANON %d %s}' % (id(self), alias or 'anon')
         self.name = alias
         self.encodedname = alias.encode('ascii', 'backslashreplace')
         self.case_sensitive = getattr(baseselectable, "case_sensitive", True)
@@ -2390,9 +2386,10 @@ class _Label(ColumnElement):
     """
     
     def __init__(self, name, obj, type=None):
-        self.name = name
         while isinstance(obj, _Label):
             obj = obj.obj
+        self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
+
         self.obj = obj.self_group(against='AS')
         self.case_sensitive = getattr(obj, "case_sensitive", True)
         self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
@@ -2422,8 +2419,6 @@ class _Label(ColumnElement):
         else:
             return column(self.name)._make_proxy(selectable=selectable)
 
-legal_characters = util.Set(string.ascii_letters + string.digits + '_')
-
 class _ColumnClause(ColumnElement):
     """Represents a generic column expression from any textual string.
     This includes columns associated with tables, aliases and select
@@ -2492,7 +2487,6 @@ class _ColumnClause(ColumnElement):
                     counter += 1
             else:
                 self.__label = self.name
-            self.__label = "".join([x for x in self.__label if x in legal_characters])
         return self.__label
 
     is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
index 384fead50b4fc4936aa59d61880a233d133143f9..6029c83088ac02df99f99e2fcf238a178dfe5fd2 100644 (file)
@@ -16,7 +16,7 @@ class LabelTypeTest(testbase.PersistTest):
 
 class LongLabelsTest(testbase.PersistTest):
     def setUpAll(self):
-        global metadata, table1
+        global metadata, table1, maxlen
         metadata = MetaData(testbase.db)
         table1 = Table("some_large_named_table", metadata,
             Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True),
@@ -24,11 +24,16 @@ class LongLabelsTest(testbase.PersistTest):
             )
             
         metadata.create_all()
+        
+        maxlen = testbase.db.dialect.max_identifier_length
+        testbase.db.dialect.max_identifier_length = lambda: 29
+        
     def tearDown(self):
         table1.delete().execute()
         
     def tearDownAll(self):
         metadata.drop_all()
+        testbase.db.dialect.max_identifier_length = maxlen
         
     def test_result(self):
         table1.insert().execute(**{"this_is_the_primarykey_column":1, "this_is_the_data_column":"data1"})
index f885dc56ba7b1460d0bb088930dd4320fe74715f..00bcff7c767ab8f042d0c7f208855397bdb8c588 100644 (file)
@@ -13,12 +13,12 @@ class UnicodeSchemaTest(testbase.PersistTest):
         metadata = MetaData(testbase.db)
         t1 = Table('unitable1', metadata,
             Column(u'méil', Integer, primary_key=True),
-            Column(u'éXXm', Integer),
+            Column(u'測試', Integer),
 
             )
         t2 = Table(u'unitéble2', metadata,
             Column(u'méil', Integer, primary_key=True, key="a"),
-            Column(u'éXXm', Integer, ForeignKey(u'unitable1.méil'), key="b"),
+            Column(u'測試', Integer, ForeignKey(u'unitable1.méil'), key="b"),
 
             )
         metadata.create_all()
@@ -31,21 +31,21 @@ class UnicodeSchemaTest(testbase.PersistTest):
         metadata.drop_all()
         
     def test_insert(self):
-        t1.insert().execute({u'méil':1, u'éXXm':5})
+        t1.insert().execute({u'méil':1, u'測試':5})
         t2.insert().execute({'a':1, 'b':1})
         
         assert t1.select().execute().fetchall() == [(1, 5)]
         assert t2.select().execute().fetchall() == [(1, 1)]
     
     def test_reflect(self):
-        t1.insert().execute({u'méil':2, u'éXXm':7})
+        t1.insert().execute({u'méil':2, u'測試':7})
         t2.insert().execute({'a':2, 'b':2})
 
         meta = MetaData(testbase.db)
         tt1 = Table(t1.name, meta, autoload=True)
         tt2 = Table(t2.name, meta, autoload=True)
-        tt1.insert().execute({u'méil':1, u'éXXm':5})
-        tt2.insert().execute({u'méil':1, u'éXXm':1})
+        tt1.insert().execute({u'méil':1, u'測試':5})
+        tt2.insert().execute({u'méil':1, u'測試':1})
 
         assert tt1.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 7), (1, 5)]
         assert tt2.select(order_by=desc(u'méil')).execute().fetchall() == [(2, 2), (1, 1)]
@@ -59,7 +59,7 @@ class UnicodeSchemaTest(testbase.PersistTest):
         mapper(A, t1, properties={
             't2s':relation(B),
             'a':t1.c[u'méil'],
-            'b':t1.c[u'éXXm']
+            'b':t1.c[u'測試']
         })
         mapper(B, t2)
         sess = create_session()