]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactoring of ANSIIdentifierPreparer to be one instance per-dialect, simplified...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Aug 2006 03:50:23 +0000 (03:50 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Aug 2006 03:50:23 +0000 (03:50 +0000)
lib/sqlalchemy/ansisql.py

index f4b0852e6f2c814290ee01254360a13b80c70094..53c6db6c47e91cd342ed8b7a7d93e8da0f1f8b5c 100644 (file)
@@ -29,7 +29,7 @@ def create_engine():
 class ANSIDialect(default.DefaultDialect):
     def __init__(self, **kwargs):
         super(ANSIDialect,self).__init__(**kwargs)
-        self._identifier_cache = weakref.WeakKeyDictionary()
+        self.identifier_preparer = self.preparer()
 
     def connect_args(self):
         return ([],{})
@@ -79,8 +79,7 @@ class ANSICompiler(sql.Compiled):
         self.paramstyle = dialect.paramstyle
         self.positional = dialect.positional
         self.positiontup = []
-        self.preparer = dialect.preparer()
-        
+        self.preparer = dialect.identifier_preparer
         
     def after_compile(self):
         # this re will search for params like :param
@@ -722,8 +721,8 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
         self.initial_quote = initial_quote
         self.final_quote = final_quote or self.initial_quote
         self.omit_schema = omit_schema
-        self.strings = {}
-        self.__visited = util.Set()
+        self.__strings = weakref.WeakKeyDictionary()
+        self.__visited = weakref.WeakKeyDictionary()
     def _escape_identifier(self, value):
         """escape an identifier.
         
@@ -749,75 +748,55 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor):
         """return true if the given identifier requires quoting."""
         return False
     
-    def __requires_quotes_cached(self, value, natural_case):
-        try:
-            return self.dialect._identifier_cache[(value, natural_case)]
-        except KeyError:
-            result = self._requires_quotes(value, natural_case)
-            self.dialect._identifier_cache[(value, natural_case)] = result
-            return result
-            
     def visit_table(self, table):
         if table in self.__visited:
             return
         
-        # cache the results within the dialect, weakly keyed to the table    
-        try:
-            (self.strings[table], self.strings[(table, 'schema')]) = self.dialect._identifier_cache[table]
-            return
-        except KeyError:
-            pass
-        
         if table.quote or self._requires_quotes(table.name, table.natural_case):
-            self.strings[table] = self._quote_identifier(table.name)
+            tablestring = self._quote_identifier(table.name)
         else:
-            self.strings[table] = table.name
+            tablestring = table.name
+
         if table.schema:
             if table.quote_schema or self._requires_quotes(table.schema, table.natural_case_schema):
-                self.strings[(table, 'schema')] = self._quote_identifier(table.schema)
+                schemastring = self._quote_identifier(table.schema)
             else: 
-                self.strings[(table, 'schema')] = table.schema
+                schemastring = table.schema
         else:
-            self.strings[(table,'schema')] = None
-        self.dialect._identifier_cache[table] = (self.strings[table], self.strings[(table, 'schema')])
+            schemastring = None
+        
+        self.__strings[table] = (tablestring, schemastring)
         
     def visit_column(self, column):
         if column in self.__visited:
             return
-
-        # cache the results within the dialect, weakly keyed to the column    
-        try:
-            self.strings[column] = self.dialect._identifier_cache[column]
-            return
-        except KeyError:
-            pass
-
         if column.quote or self._requires_quotes(column.name, column.natural_case):
-            self.strings[column] = self._quote_identifier(column.name)
+            self.__strings[column] = self._quote_identifier(column.name)
         else:
-            self.strings[column] = column.name
-        self.dialect._identifier_cache[column] = self.strings[column]
+            self.__strings[column] = column.name
         
-    def __start_visit(self, obj):
+    def __analyze_identifiers(self, obj):
+        """insure that each object we encounter is analyzed only once for its lifetime."""
         if obj in self.__visited:
             return
         if isinstance(obj, schema.SchemaItem):
             obj.accept_schema_visitor(self)
-        self.__visited.add(obj)
+        self.__visited[obj] = True
          
     def __prepare_table(self, table, use_schema=False):
-        self.__start_visit(table)
-        if not self.omit_schema and use_schema and self.strings.get((table, 'schema'), None) is not None:
-            return self.strings[(table, 'schema')] + "." + self.strings.get(table, table.name)
+        self.__analyze_identifiers(table)
+        tablename = self.__strings.get(table, (table.name, None))[0]
+        if not self.omit_schema and use_schema and self.__strings.get(table, (None,None))[1] is not None:
+            return self.__strings[table][1] + "." + tablename
         else:
-            return self.strings.get(table, table.name)
+            return tablename
 
     def __prepare_column(self, column, use_table=True, **kwargs):
-        self.__start_visit(column)
+        self.__analyze_identifiers(column)
         if use_table:
-            return self.__prepare_table(column.table, **kwargs) + "." + self.strings.get(column, column.name)
+            return self.__prepare_table(column.table, **kwargs) + "." + self.__strings.get(column, column.name)
         else:
-            return self.strings.get(column, column.name)
+            return self.__strings.get(column, column.name)
     
     def format_table(self, table, use_schema=True):
         """Prepare a quoted table and schema name"""