]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
removed get_str(), get_from_text() from ansicompiler. removes a few hundred method...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jul 2007 22:05:08 +0000 (22:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Jul 2007 22:05:08 +0000 (22:05 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/oracle.py
test/orm/unitofwork.py

index 3e6ddd34c5e98d3c685372f9fb0be5b4b4d27da2..361fd7b1ea5e4d22d6f2a00dd3b3cfd27e63ed73 100644 (file)
@@ -140,12 +140,6 @@ class ANSICompiler(engine.Compiled):
         # for aliases
         self.generated_ids = {}
         
-        # True if this compiled represents an INSERT
-        self.isinsert = False
-
-        # True if this compiled represents an UPDATE
-        self.isupdate = False
-
         # default formatting style for bind parameters
         self.bindtemplate = ":%s"
 
@@ -204,12 +198,6 @@ class ANSICompiler(engine.Compiled):
         text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
         self.strings[self.statement] = text
 
-    def get_from_text(self, obj):
-        return self.froms.get(obj, None)
-
-    def get_str(self, obj):
-        return self.strings[obj]
-    
     def is_subquery(self, select):
         return self.correlate_state[select].get('is_subquery', False)
         
@@ -337,13 +325,13 @@ class ANSICompiler(engine.Compiled):
             sep = " "
         else:
             sep = " " + sep + " "
-        self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep)
+        self.strings[list] = string.join([s for s in [self.strings[c] for c in list.clauses] if s is not None], sep)
 
     def apply_function_parens(self, func):
         return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
 
     def visit_calculatedclause(self, clause):
-        self.strings[clause] = self.get_str(clause.clause_expr)
+        self.strings[clause] = self.strings[clause.clause_expr]
 
     def visit_cast(self, cast):
         if len(self.select_stack):
@@ -358,12 +346,12 @@ class ANSICompiler(engine.Compiled):
             self.strings[func] = ".".join(func.packagenames + [func.name])
             self.froms[func] = self.strings[func]
         else:
-            self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr)
+            self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.strings[func.clause_expr]
             self.froms[func] = self.strings[func]
 
     def visit_compound_select(self, cs):
-        text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
-        group_by = self.get_str(cs._group_by_clause)
+        text = string.join([self.strings[c] for c in cs.selects], " " + cs.keyword + " ")
+        group_by = self.strings[cs._group_by_clause]
         if group_by:
             text += " GROUP BY " + group_by
         text += self.order_by_clause(cs)            
@@ -372,7 +360,7 @@ class ANSICompiler(engine.Compiled):
         self.froms[cs] = "(" + text + ")"
 
     def visit_unary(self, unary):
-        s = self.get_str(unary.element)
+        s = self.strings[unary.element]
         if unary.operator:
             s = unary.operator + " " + s
         if unary.modifier:
@@ -380,10 +368,10 @@ class ANSICompiler(engine.Compiled):
         self.strings[unary] = s
         
     def visit_binary(self, binary):
-        result = self.get_str(binary.left)
+        result = self.strings[binary.left]
         if binary.operator is not None:
             result += " " + self.binary_operator_string(binary)
-        result += " " + self.get_str(binary.right)
+        result += " " + self.strings[binary.right]
         self.strings[binary] = result
 
     def binary_operator_string(self, binary):
@@ -455,8 +443,8 @@ class ANSICompiler(engine.Compiled):
         return self.bindtemplate % name
 
     def visit_alias(self, alias):
-        self.froms[alias] = self.get_from_text(alias.original) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
-        self.strings[alias] = self.get_str(alias.original)
+        self.froms[alias] = self.froms[alias.original] + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
+        self.strings[alias] = self.strings[alias.original]
 
     def enter_select(self, select):
         select._calculate_correlations(self.correlate_state)
@@ -510,19 +498,19 @@ class ANSICompiler(engine.Compiled):
                     inner_columns[labelname] = l
                 else:
                     self.traverse(co)
-                    inner_columns[self.get_str(co)] = co
+                    inner_columns[self.strings[co]] = co
             else:
                 l = self.label_select_column(select, co)
                 if l is not None:
                     self.traverse(l)
-                    inner_columns[self.get_str(l.obj)] = l
+                    inner_columns[self.strings[l.obj]] = l
                 else:
                     self.traverse(co)
-                    inner_columns[self.get_str(co)] = co
+                    inner_columns[self.strings[co]] = co
                     
         self.select_stack.pop(-1)
 
-        collist = string.join([self.get_str(v) for v in inner_columns.values()], ', ')
+        collist = string.join([self.strings[v] for v in inner_columns.values()], ', ')
 
         text = "SELECT "
         text += self.visit_select_precolumns(select)
@@ -541,9 +529,7 @@ class ANSICompiler(engine.Compiled):
                 else:
                     whereclause = w
 
-            t = self.get_from_text(f)
-            if t is not None:
-                from_strings.append(t)
+            from_strings.append(self.froms[f])
 
         if len(froms):
             text += " \nFROM "
@@ -552,16 +538,16 @@ class ANSICompiler(engine.Compiled):
             text += self.default_from()
 
         if whereclause is not None:
-            t = self.get_str(whereclause)
+            t = self.strings[whereclause]
             if t:
                 text += " \nWHERE " + t
 
-        group_by = self.get_str(select._group_by_clause)
+        group_by = self.strings[select._group_by_clause]
         if group_by:
             text += " GROUP BY " + group_by
 
         if select._having is not None:
-            t = self.get_str(select._having)
+            t = self.strings[select._having]
             if t:
                 text += " \nHAVING " + t
 
@@ -586,7 +572,7 @@ class ANSICompiler(engine.Compiled):
         return (select._limit or select._offset) and self.limit_clause(select) or ""
 
     def order_by_clause(self, select):
-        order_by = self.get_str(select._order_by_clause)
+        order_by = self.strings[select._order_by_clause]
         if order_by:
             return " ORDER BY " + order_by
         else:
@@ -613,15 +599,15 @@ class ANSICompiler(engine.Compiled):
         self.strings[table] = ""
 
     def visit_join(self, join):
-        righttext = self.get_from_text(join.right)
+        righttext = self.froms[join.right]
         if join.right._group_parenthesized():
             righttext = "(" + righttext + ")"
         if join.isouter:
-            self.froms[join] = (self.get_from_text(join.left) + " LEFT OUTER JOIN " + righttext +
-            " ON " + self.get_str(join.onclause))
+            self.froms[join] = (self.froms[join.left] + " LEFT OUTER JOIN " + righttext +
+            " ON " + self.strings[join.onclause])
         else:
-            self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
-            " ON " + self.get_str(join.onclause))
+            self.froms[join] = (self.froms[join.left] + " JOIN " + righttext +
+            " ON " + self.strings[join.onclause])
         self.strings[join] = self.froms[join]
 
     def visit_insert_column_default(self, column, default, parameters):
@@ -699,9 +685,9 @@ class ANSICompiler(engine.Compiled):
                 self.inline_params.add(col)
                 self.traverse(p)
                 if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
-                    return "(" + self.get_str(p) + ")"
+                    return "(" + self.strings[p] + ")"
                 else:
-                    return self.get_str(p)
+                    return self.strings[p]
 
         text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
          " VALUES (" + string.join([create_param(*c) for c in colparams], ', ') + ")")
@@ -733,14 +719,14 @@ class ANSICompiler(engine.Compiled):
                 self.traverse(p)
                 self.inline_params.add(col)
                 if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnElement):
-                    return "(" + self.get_str(p) + ")"
+                    return "(" + self.strings[p] + ")"
                 else:
-                    return self.get_str(p)
+                    return self.strings[p]
 
         text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), create_param(*c)) for c in colparams], ', ')
 
         if update_stmt._whereclause:
-            text += " WHERE " + self.get_str(update_stmt._whereclause)
+            text += " WHERE " + self.strings[update_stmt._whereclause]
 
         self.strings[update_stmt] = text
 
@@ -805,7 +791,7 @@ class ANSICompiler(engine.Compiled):
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
         if delete_stmt._whereclause:
-            text += " WHERE " + self.get_str(delete_stmt._whereclause)
+            text += " WHERE " + self.strings[delete_stmt._whereclause]
 
         self.strings[delete_stmt] = text
         
@@ -822,7 +808,7 @@ class ANSICompiler(engine.Compiled):
         self.strings[savepoint_stmt] = text
     
     def __str__(self):
-        return self.get_str(self.statement)
+        return self.strings[self.statement]
 
 class ANSISchemaBase(engine.SchemaIterator):
     def find_alterables(self, tables):
index fc6c5bf366cf41e8045425c823b17f7766302120..0d6bd3360542cb74feeb17aad99f7c1c7c9ca11b 100644 (file)
@@ -305,8 +305,8 @@ class FBCompiler(ansisql.ANSICompiler):
 
     def visit_alias(self, alias):
         # Override to not use the AS keyword which FB 1.5 does not like
-        self.froms[alias] = self.get_from_text(alias.original) + " " + self.preparer.format_alias(alias)
-        self.strings[alias] = self.get_str(alias.original)
+        self.froms[alias] = self.froms[alias.original] + " " + self.preparer.format_alias(alias)
+        self.strings[alias] = self.strings[alias.original]
 
     def visit_function(self, func):
         if len(func.clauses):
index 5382ce2a42424f2a4932f24b0d6b41967a0fbd58..d3f397118200b0e3064adc2d47a76e196ebc5ffd 100644 (file)
@@ -426,9 +426,9 @@ class InfoCompiler(ansisql.ANSICompiler):
         except:
             li = [ c for c in list.clauses ]
         if list.parens:
-            self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in li] if s is not None ], ', ') + ")"
+            self.strings[list] = "(" + string.join([s for s in [self.strings[c] for c in li] if s is not None ], ', ') + ")"
         else:
-            self.strings[list] = string.join([s for s in [self.get_str(c) for c in li] if s is not None], ', ')
+            self.strings[list] = string.join([s for s in [self.strings[c] for c in li] if s is not None], ', ')
 
 class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, first_pk=False):
index e9e49bfbb03f340134ff4b507be905d8c07dc3b2..0e5a41a34fdd2d97f880ca69779e7964bbf25691 100644 (file)
@@ -856,7 +856,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
         return ''
 
     def order_by_clause(self, select):
-        order_by = self.get_str(select._order_by_clause)
+        order_by = self.strings[select._order_by_clause]
 
         # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
         if order_by and (not self.is_subquery(select) or select._limit):
index 8fd00e504c2fa49599a5ed9c5aeb24c92c36f917..a2b469a304e8f2e64f252ec6f940fde8a9e12ca5 100644 (file)
@@ -475,7 +475,7 @@ class OracleCompiler(ansisql.ANSICompiler):
         if self.dialect.use_ansi:
             return ansisql.ANSICompiler.visit_join(self, join)
 
-        self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
+        self.froms[join] = self.froms[join.left] + ", " + self.froms[join.right]
         where = self.wheres.get(join.left, None)
         if where is not None:
             self.wheres[join] = sql.and_(where, join.onclause)
@@ -507,8 +507,8 @@ class OracleCompiler(ansisql.ANSICompiler):
     def visit_alias(self, alias):
         """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
 
-        self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
-        self.strings[alias] = self.get_str(alias.original)
+        self.froms[alias] = self.froms[alias.original] + " " + alias.name
+        self.strings[alias] = self.strings[alias.original]
 
     def visit_column(self, column):
         ansisql.ANSICompiler.visit_column(self, column)
@@ -573,7 +573,7 @@ class OracleCompiler(ansisql.ANSICompiler):
 
     def visit_binary(self, binary):
         if binary.operator == '%': 
-            self.strings[binary] = ("MOD(%s,%s)"%(self.get_str(binary.left), self.get_str(binary.right)))
+            self.strings[binary] = ("MOD(%s,%s)"%(self.strings[binary.left], self.strings[binary.right]))
         else:
             return ansisql.ANSICompiler.visit_binary(self, binary)
         
index d9c3cf4c3085f4ed5f3bee6ebbc9ca95570842c6..509fc6acabf30450b361b38a2f2cfe3cb11a3480 100644 (file)
@@ -328,7 +328,7 @@ class MutableTypesTest(UnitOfWorkTest):
 class PKTest(UnitOfWorkTest):
     def setUpAll(self):
         UnitOfWorkTest.setUpAll(self)
-        global table, table2, table3
+        global table, table2, table3, metadata
         metadata = MetaData(db)
         table = Table(
             'multipk', metadata,