]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
simplification to quoting to just cache strings per-dialect, added quoting for alias...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Sep 2006 01:56:31 +0000 (01:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Sep 2006 01:56:31 +0000 (01:56 +0000)
fixes [ticket:294]

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py
test/orm/selectresults.py
test/sql/quote.py

index d65e8ad338f60b1ad9a56df474bccda2b6c9cf7a..d053f738983377ef70c03370d61829fecd138241 100644 (file)
@@ -32,10 +32,11 @@ def create_engine():
     return engine.ComposedSQLEngine(None, ANSIDialect())
     
 class ANSIDialect(default.DefaultDialect):
-    def __init__(self, **kwargs):
+    def __init__(self, cache_identifiers=True, **kwargs):
         super(ANSIDialect,self).__init__(**kwargs)
         self.identifier_preparer = self.preparer()
-
+        self.cache_identifiers = cache_identifiers
+        
     def connect_args(self):
         return ([],{})
 
@@ -158,7 +159,7 @@ class ANSICompiler(sql.Compiled):
     def visit_label(self, label):
         if len(self.select_stack):
             self.typemap.setdefault(label.name.lower(), label.obj.type)
-        self.strings[label] = self.strings[label.obj] + " AS "  + label.name
+        self.strings[label] = self.strings[label.obj] + " AS "  + self.preparer.format_label(label)
         
     def visit_column(self, column):
         if len(self.select_stack):
@@ -289,7 +290,7 @@ class ANSICompiler(sql.Compiled):
         return self.bindtemplate % name
         
     def visit_alias(self, alias):
-        self.froms[alias] = self.get_from_text(alias.original) + " AS " + alias.name
+        self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias)
         self.strings[alias] = self.get_str(alias.original)
 
     def visit_select(self, select):
@@ -717,7 +718,7 @@ class ANSISchemaDropper(engine.SchemaIterator):
 class ANSIDefaultRunner(engine.DefaultRunner):
     pass
 
-class ANSIIdentifierPreparer(schema.SchemaVisitor):
+class ANSIIdentifierPreparer(object):
     """handles quoting and case-folding of identifiers based on options"""
     def __init__(self, dialect, initial_quote='"', final_quote=None, omit_schema=False):
         """Constructs a new ANSIIdentifierPreparer object.
@@ -731,8 +732,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
         self.initial_quote = initial_quote
         self.final_quote = final_quote or self.initial_quote
         self.omit_schema = omit_schema
-        self.__strings = weakref.WeakKeyDictionary()
-        self.__visited = weakref.WeakKeyDictionary()
+        self.__strings = {}
     def _escape_identifier(self, value):
         """escape an identifier.
         
@@ -771,68 +771,24 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
             or bool(len([x for x in str(value) if x not in self._legal_characters()])) \
             or (case_sensitive and value.lower() != value)
     
-    def visit_table(self, table):
-        if table in self.__visited:
-            return
-        
-        if table.quote or self._requires_quotes(table.name, table.case_sensitive):
-            tablestring = self._quote_identifier(table.name)
-        else:
-            tablestring = table.name
-
-        if table.schema:
-            if table.quote_schema or self._requires_quotes(table.schema, table.case_sensitive_schema):
-                schemastring = self._quote_identifier(table.schema)
-            else: 
-                schemastring = table.schema
-        else:
-            schemastring = None
-        
-        self.__strings[table] = (tablestring, schemastring)
-        
-    def visit_column(self, column):
-        if column in self.__visited:
-            return
-        if column.quote or self._requires_quotes(column.name, column.case_sensitive):
-            self.__strings[column] = self._quote_identifier(column.name)
-        else:
-            self.__strings[column] = column.name
-    
-    def visit_sequence(self, sequence):
-        if sequence in self.__visited:
-            return
-        if sequence.quote or self._requires_quotes(sequence.name, sequence.case_sensitive):
-            self.__strings[sequence] = self._quote_identifier(sequence.name)
-        else:
-            self.__strings[sequence] = sequence.name
-                
-    def __analyze_identifiers(self, obj):
-        """insure that each object we encounter is analyzed only once for its lifetime."""
-        if obj in self.__visited:
-            return
-        if isinstance(obj, schema.SchemaItem):
-            obj.accept_schema_visitor(self)
-        self.__visited[obj] = True
-    
-    def __prepare_sequence(self, sequence):
-        self.__analyze_identifiers(sequence)
-        return self.__strings.get(sequence, sequence.name)
-             
-    def __prepare_table(self, table, use_schema=False):
-        self.__analyze_identifiers(table)
-        tablename = self.__strings.get(table, (table.name, None))[0]
-        if not self.omit_schema and use_schema and self.__strings.get(table, (None,None))[1] is not None:
-            return self.__strings[table][1] + "." + tablename
-        else:
-            return tablename
-
-    def __prepare_column(self, column, use_table=True, **kwargs):
-        self.__analyze_identifiers(column)
-        if use_table:
-            return self.__prepare_table(column.table, **kwargs) + "." + self.__strings.get(column, column.name)
+    def __generic_obj_format(self, obj, ident):
+        if getattr(obj, 'quote', False):
+            return self._quote_identifier(ident)
+        if self.dialect.cache_identifiers:
+            try:
+                return self.__strings[ident]
+            except KeyError:
+                if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
+                    self.__strings[ident] = self._quote_identifier(ident)
+                else:
+                    self.__strings[ident] = ident
+                return self.__strings[ident]
         else:
-            return self.__strings.get(column, column.name)
-   
+            if self._requires_quotes(ident, getattr(obj, 'case_sensitive', ident == ident.lower())):
+                return self._quote_identifier(ident)
+            else:
+                return ident
+            
     def should_quote(self, object):
         return object.quote or self._requires_quotes(object.name, object.case_sensitive) 
  
@@ -840,16 +796,38 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
         return object.quote or self._requires_quotes(object.name, object.case_sensitive)
         
     def format_sequence(self, sequence):
-        return self.__prepare_sequence(sequence)
+        return self.__generic_obj_format(sequence, sequence.name)
+    
+    def format_label(self, label):
+        return self.__generic_obj_format(label, label.name)
+
+    def format_alias(self, alias):
+        return self.__generic_obj_format(alias, alias.name)
         
     def format_table(self, table, use_schema=True):
         """Prepare a quoted table and schema name"""
-        return self.__prepare_table(table, use_schema=use_schema)
+        result = self.__generic_obj_format(table, table.name)
+        if use_schema and getattr(table, "schema", None):
+            result = self.__generic_obj_format(table, table.schema) + "." + result
+        return result
     
-    def format_column(self, column):
+    def format_column(self, column, use_table=False):
         """Prepare a quoted column name """
-        return self.__prepare_column(column, use_table=False)
-    
+        # TODO: isinstance alert !  get ColumnClause and Column to better
+        # differentiate themselves
+        if isinstance(column, schema.SchemaItem):
+            if use_table:
+                return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name)
+            else:
+                return self.__generic_obj_format(column, column.name)
+        else:
+            # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
+            if use_table:
+                return column.table.name + "." + column.name
+            else:
+                return column.name
+            
     def format_column_with_table(self, column):
         """Prepare a quoted column name with table name"""
-        return self.__prepare_column(column)
+        return self.format_column(column, use_table=True)
+
index 596e0e8eef61c2a2b7b2964b8f01af951b7866fd..d2e270c32cfb2af7df53dcfd7ea8cfa771c629cc 100644 (file)
@@ -1148,6 +1148,7 @@ class Alias(FromClause):
                 alias = alias[0:15]
             alias = alias + "_" + hex(random.randint(0, 65535))[2:]
         self.name = alias
+        self.case_sensitive = getattr(baseselectable, "case_sensitive", alias.lower() != alias)
         
     def _locate_oid_column(self):
         if self.selectable.oid_column is not None:
@@ -1180,6 +1181,7 @@ class Label(ColumnElement):
         while isinstance(obj, Label):
             obj = obj.obj
         self.obj = obj
+        self.case_sensitive = getattr(obj, "case_sensitive", name.lower() != name)
         self.type = sqltypes.to_instance(type)
         obj.parens=True
     key = property(lambda s: s.name)
@@ -1206,7 +1208,7 @@ class ColumnClause(ColumnElement):
     def _get_label(self):
         if self.__label is None:
             if self.table is not None and self.table.named_with_column():
-                self.__label =  self.table.name + "_" + self.name
+                self.__label = self.table.name + "_" + self.name
                 if self.table.c.has_key(self.__label) or len(self.__label) >= 30:
                     self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:]
             else:
index c4b1d6a56eb611b760ed463451b037d3cce9b9d9..6997dfe6bb520665d464305d34a28f06225614da 100644 (file)
@@ -32,6 +32,7 @@ class SelectResultsTest(PersistTest):
         global foo
         foo.drop()
         self.uninstall_threadlocal()
+        clear_mappers()
     
     def test_selectby(self):
         res = self.query.select_by(range=5)
@@ -111,6 +112,7 @@ class SelectResultsTest2(PersistTest):
     def tearDownAll(self):
         metadata.drop_all()
         self.uninstall_threadlocal()
+        clear_mappers()
 
     def test_distinctcount(self):
         res = self.query.select()
@@ -120,6 +122,42 @@ class SelectResultsTest2(PersistTest):
         res = self.query.select(and_(table1.c.id==table2.c.t1id,table2.c.t1id==1), distinct=True)
         self.assertEqual(res.count(), 1)
 
+class SelectResultsTest3(PersistTest):
+    def setUpAll(self):
+        self.install_threadlocal()
+        global metadata, table1, table2
+        metadata = BoundMetaData(testbase.db)
+        table1 = Table('Table1', metadata,
+            Column('ID', Integer, primary_key=True),
+            )
+        table2 = Table('Table2', metadata,
+            Column('T1ID', Integer, ForeignKey("Table1.ID"), primary_key=True),
+            Column('NUM', Integer, primary_key=True),
+            )
+        assign_mapper(Obj1, table1, extension=SelectResultsExt())
+        assign_mapper(Obj2, table2, extension=SelectResultsExt())
+        metadata.create_all()
+        table1.insert().execute({'ID':1},{'ID':2},{'ID':3},{'ID':4})
+        table2.insert().execute({'NUM':1,'T1ID':1},{'NUM':2,'T1ID':1},{'NUM':3,'T1ID':1},\
+{'NUM':4,'T1ID':2},{'NUM':5,'T1ID':2},{'NUM':6,'T1ID':3})
+
+    def setUp(self):
+        self.query = Obj1.mapper.query()
+        #self.orig = self.query.select_whereclause()
+        #self.res = self.query.select()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+        self.uninstall_threadlocal()
+        clear_mappers()
+        
+    def test_distinctcount(self):
+        res = self.query.select()
+        assert res.count() == 4
+        res = self.query.select(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1))
+        assert res.count() == 3
+        res = self.query.select(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1), distinct=True)
+        self.assertEqual(res.count(), 1)
 
 
 if __name__ == "__main__":
index 6b38accbd9f9fcc0269ccc8f4bedd1d086abbf6f..3e1e95a2664ea0d25758d71be68ca2b76dcfc919 100644 (file)
@@ -77,6 +77,18 @@ class QuoteTest(PersistTest):
         assert lcmetadata.case_sensitive is False
         assert t1.c.UcCol.case_sensitive is False
         assert t2.c.normalcol.case_sensitive is False
+    
+    def testlabels(self):
+        """test the quoting of labels.
+        
+        if labels arent quoted, a query in postgres in particular will fail since it produces:
+        
+        SELECT LaLa.lowercase, LaLa."UPPERCASE", LaLa."MixedCase", LaLa."ASC" 
+        FROM (SELECT DISTINCT "WorstCase1".lowercase AS lowercase, "WorstCase1"."UPPERCASE" AS UPPERCASE, "WorstCase1"."MixedCase" AS MixedCase, "WorstCase1"."ASC" AS ASC \nFROM "WorstCase1") AS LaLa
+        
+        where the "UPPERCASE" column of "LaLa" doesnt exist.
+        """
+        x = table1.select(distinct=True).alias("LaLa").select().scalar()
         
         
 if __name__ == "__main__":