]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactorings to sql generation, unions, engine location
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Dec 2005 08:41:18 +0000 (08:41 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Dec 2005 08:41:18 +0000 (08:41 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/sql.py

index 99402aebb555405a218f48b06e618c3b47aaf156..85b687ad1501b801576ea13c1a212e1e0659bca9 100644 (file)
@@ -170,7 +170,13 @@ class ANSICompiler(sql.Compiled):
 
     def visit_function(self, func):
         self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
-        
+    
+    def visit_compound_select(self, cs):
+        text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ")
+        for tup in cs.clauses:
+            text += " " + tup[0] + " " + self.get_str(tup[1])
+        self.strings[cs] = text
+            
     def visit_binary(self, binary):
         result = self.get_str(binary.left)
         if binary.operator is not None:
@@ -245,7 +251,7 @@ class ANSICompiler(sql.Compiled):
             if t:
                 text += " \nWHERE " + t
 
-        for tup in select._clauses:
+        for tup in select.clauses:
             text += " " + tup[0] + " " + self.get_str(tup[1])
 
         if select.having is not None:
@@ -275,6 +281,7 @@ class ANSICompiler(sql.Compiled):
         else:
             self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext +
             " ON " + self.get_str(join.onclause))
+        self.strings[join] = self.froms[join]
         
     def visit_insert(self, insert_stmt):
         self.isinsert = True
index 9763c697849149026e2675bcf737fd51f776e10b..91547d12c805d0d9a38fcb864a56d78df0a4526e 100644 (file)
@@ -151,10 +151,10 @@ class OracleCompiler(ansisql.ANSICompiler):
     """oracle compiler modifies the lexical structure of Select statements to work under 
     non-ANSI configured Oracle databases, if the use_ansi flag is False."""
     
-    def __init__(self, engine, statement, bindparams, use_ansi = True):
+    def __init__(self, engine, statement, bindparams, use_ansi = True, **kwargs):
         self._outertable = None
         self._use_ansi = use_ansi
-        ansisql.ANSICompiler.__init__(self, engine, statement, bindparams)
+        ansisql.ANSICompiler.__init__(self, engine, statement, bindparams, **kwargs)
         
     def visit_join(self, join):
         if self._use_ansi:
index 486404f33a9bb4ad849fc73817908b1ff4702fdb..6a17b445524a6ed60e8f231a64b17dbd9acd127f 100644 (file)
@@ -168,22 +168,12 @@ func = FunctionGateway()
 def _compound_clause(keyword, *clauses):
     return CompoundClause(keyword, *clauses)
 
-def _compound_select(keyword, *selects, **params):
-    if len(selects) == 0:
-        return None
-    s = selects[0]
-    for n in selects[1:]:
-        s.append_clause(keyword, n)
-
-    if params.get('order_by', None) is not None:
-        s.order_by(*params['order_by'])
-
-    return s
+def _compound_select(keyword, *selects, **kwargs):
+    return CompoundSelect(keyword, *selects, **kwargs)
 
 def _is_literal(element):
     return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem)
 
-
 class ClauseVisitor(schema.SchemaVisitor):
     """builds upon SchemaVisitor to define the visiting of SQL statement elements in 
     addition to Schema elements."""
@@ -192,6 +182,7 @@ class ClauseVisitor(schema.SchemaVisitor):
     def visit_bindparam(self, bindparam):pass
     def visit_textclause(self, textclause):pass
     def visit_compound(self, compound):pass
+    def visit_compound_select(self, compound):pass
     def visit_binary(self, binary):pass
     def visit_alias(self, alias):pass
     def visit_select(self, select):pass
@@ -258,8 +249,13 @@ class ClauseElement(object):
         change if the underlying structure of the ClauseElement changes.""" 
         raise NotImplementedError(repr(self))
     def _get_from_objects(self):
+        """returns objects represented in this ClauseElement that should be added to the
+        FROM list of a query."""
         raise NotImplementedError(repr(self))
     def _process_from_dict(self, data, asfrom):
+        """given a dictionary attached to a Select object, places the appropriate
+        FROM objects in the dictionary corresponding to this ClauseElement,
+        and possibly removes or modifies others."""
         for f in self._get_from_objects():
             data.setdefault(f.id, f)
         if asfrom:
@@ -274,6 +270,20 @@ class ClauseElement(object):
         new structure can then be restructured without affecting the original."""
         return self
 
+    def _find_engine(self):
+        try:
+            if self._engine is not None:
+                return self._engine
+        except AttributeError:
+            pass
+        for f in self._get_from_objects():
+            engine = f.engine
+            if engine is not None: 
+                return engine
+        else:
+            return None
+            
+    engine = property(lambda s: s._find_engine())
 
     def compile(self, engine = None, bindparams = None, typemap=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
@@ -281,18 +291,19 @@ class ClauseElement(object):
         bindparams is a dictionary representing the default bind parameters to be used with 
         the statement.  """
         if engine is None:
-            for f in self._get_from_objects():
-                engine = f.engine
-                if engine is not None: break
-            else:
-                import sqlalchemy.ansisql as ansisql
-                engine = ansisql.engine()
-                #raise "no engine supplied, and no engine could be located within the clauses!"
+            engine = self.engine
+
+        if engine is None:
+            raise "no SQLEngine could be located within this ClauseElement."
 
         return engine.compile(self, bindparams = bindparams, typemap=typemap)
 
     def __str__(self):
-        return str(self.compile())
+        e = self.engine
+        if e is None:
+            import sqlalchemy.ansisql as ansisql
+            e = ansisql.engine()
+        return str(self.compile(e))
         
     def execute(self, *multiparams, **params):
         """compiles and executes this SQL expression using its underlying SQLEngine. the
@@ -410,8 +421,6 @@ class FromClause(ClauseElement):
         # this could also be [self], at the moment it doesnt matter to the Select object
         return []
         
-    engine = property(lambda s: None)
-    
     def hash_key(self):
         return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name))
             
@@ -444,7 +453,7 @@ class TextClause(ClauseElement):
     def __init__(self, text = "", engine=None, isliteral=False):
         self.text = text
         self.parens = False
-        self.engine = engine
+        self._engine = engine
         self.id = id(self)
         if isliteral:
             if isinstance(text, int) or isinstance(text, long):
@@ -489,7 +498,10 @@ class ClauseList(ClauseElement):
             c.accept_visitor(visitor)
         visitor.visit_clauselist(self)
     def _get_from_objects(self):
-        return []
+        f = []
+        for c in self.clauses:
+            f += c._get_from_objects()
+        return f
 
 class CompoundClause(ClauseList):
     """represents a list of clauses joined by an operator, such as AND or OR.  
@@ -600,10 +612,6 @@ class Selectable(FromClause):
 
     def alias(self, name):
         return Alias(self, name)
-    def union(self, other, **kwargs):
-        return union(self, other, **kwargs)
-    def union_all(self, other, **kwargs):
-        return union_all(self, other, **kwargs)
     def group_parenthesized(self):
         """indicates if this Selectable requires parenthesis when grouped into a compound
         statement"""
@@ -644,7 +652,7 @@ class Join(Selectable):
         return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
 
     def select(self, whereclauses = None, **params):
-        return select([self.left, self.right], and_(self.onclause, whereclauses), **params)
+        return select([self.left, self.right], whereclauses, from_obj=[self], **params)
 
     def accept_visitor(self, visitor):
         self.left.accept_visitor(visitor)
@@ -706,8 +714,6 @@ class Alias(Selectable):
     engine = property(lambda s: s.selectable.engine)
 
 
-
-
 class ColumnImpl(Selectable, CompareMixin):
     """Selectable implementation that gets attached to a schema.Column object."""
     
@@ -821,9 +827,52 @@ class TableImpl(Selectable):
 
     def drop(self, **params):
         self.table.engine.drop(self.table)
+
+class TailClauseMixin(object):
+    def order_by(self, *clauses):
+        self._append_clause('order_by_clause', "ORDER BY", *clauses)
+    def group_by(self, *clauses):
+        self._append_clause('group_by_clause', "GROUP BY", *clauses)
+    def _append_clause(self, attribute, prefix, *clauses):
+        if not hasattr(self, attribute):
+            l = ClauseList(*clauses)
+            setattr(self, attribute, l)
+            self.append_clause(prefix, l)
+        else:
+            getattr(self, attribute).clauses  += clauses
+    def append_clause(self, keyword, clause):
+        if type(clause) == str:
+            clause = TextClause(clause)
+        self.clauses.append((keyword, clause))
+            
+class CompoundSelect(Selectable, TailClauseMixin):
+    def __init__(self, keyword, *selects, **kwargs):
+        self.keyword = keyword
+        self.selects = selects
+        self.clauses = []
+        order_by = kwargs.get('order_by', None)
+        if order_by:
+            self.order_by(*order_by)
+        group_by = kwargs.get('group_by', None)
+        if group_by:
+            self.group_by(*group_by)
+
+    columns = property(lambda s:s.selects[0].columns)
+    def accept_visitor(self, visitor):
+        for tup in self.clauses:
+            tup[1].accept_visitor(visitor)
+        for s in self.selects:
+            s.accept_visitor(visitor)
+        visitor.visit_compound_select(self)
+    def _find_engine(self):
+        for s in self.selects:
+            e = s._find_engine()
+            if e:
+                return e
+        else:
+            return None
         
-    
-class Select(Selectable):
+class Select(Selectable, TailClauseMixin):
     """finally, represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
     def __init__(self, columns, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None):
@@ -842,11 +891,11 @@ class Select(Selectable):
         # indicates if this select statement is a subquery as a criterion
         # inside of a WHERE clause
         self.is_where = False
+        self.clauses = []
 
         self.distinct = distinct
         self._text = None
         self._raw_columns = []
-        self._clauses = []
         self._correlated = None
         self._correlator = Select.CorrelatedVisitor(self, False)
         self._wherecorrelator = Select.CorrelatedVisitor(self, True)
@@ -929,18 +978,6 @@ class Select(Selectable):
 
         fromclause.accept_visitor(self._correlator)
         fromclause._process_from_dict(self._froms, True)
-        
-    def append_clause(self, keyword, clause):
-        if type(clause) == str:
-            clause = TextClause(clause)
-        self._clauses.append((keyword, clause))
-        
-    def compile(self, engine = None, bindparams = None):
-        if engine is None:
-            engine = self.engine
-        if engine is None:
-            raise "no engine supplied, and no engine could be located within the clauses!"
-        return engine.compile(self, bindparams)
 
     def _get_froms(self):
         return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)]
@@ -953,25 +990,17 @@ class Select(Selectable):
             self.whereclause.accept_visitor(visitor)
         if self.having is not None:
             self.having.accept_visitor(visitor)
-        for tup in self._clauses:
+        for tup in self.clauses:
             tup[1].accept_visitor(visitor)
             
         visitor.visit_select(self)
     
-    def order_by(self, *clauses):
-        self._append_clause('order_by_clause', "ORDER BY", *clauses)
-    def group_by(self, *clauses):
-        self._append_clause('group_by_clause', "GROUP BY", *clauses)
-    def _append_clause(self, attribute, prefix, *clauses):
-        if not hasattr(self, attribute):
-            l = ClauseList(*clauses)
-            setattr(self, attribute, l)
-            self.append_clause(prefix, l)
-        else:
-            getattr(self, attribute).clauses  += clauses
-                
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
+    def union(self, other, **kwargs):
+        return union(self, other, **kwargs)
+    def union_all(self, other, **kwargs):
+        return union_all(self, other, **kwargs)
 
     def _find_engine(self):
         """tries to return a SQLEngine, either explicitly set in this object, or searched
@@ -988,7 +1017,7 @@ class Select(Selectable):
             
         return None
 
-    engine = property(lambda s: s._find_engine())
+#    engine = property(lambda s: s._find_engine())
     
     def _get_from_objects(self):
         if self.is_where:
@@ -1068,21 +1097,13 @@ class UpdateBase(ClauseElement):
                 values.append((c, value))
         return values
 
-    def compile(self, engine = None, bindparams = None):
-        if engine is None:
-            engine = self.engine
-            
-        if engine is None:
-            raise "no engine supplied, and no engine could be located within the clauses!"
-
-        return engine.compile(self, bindparams)
 
 class Insert(UpdateBase):
     def __init__(self, table, values=None, **params):
         self.table = table
         self.select = None
         self.parameters = self._process_colparams(values)
-        self.engine = self.table.engine
+        self._engine = self.table.engine
         
     def accept_visitor(self, visitor):
         if self.select is not None:
@@ -1095,7 +1116,7 @@ class Update(UpdateBase):
         self.table = table
         self.whereclause = whereclause
         self.parameters = self._process_colparams(values)
-        self.engine = self.table.engine
+        self._engine = self.table.engine
 
     def accept_visitor(self, visitor):
         if self.whereclause is not None:
@@ -1106,7 +1127,7 @@ class Delete(UpdateBase):
     def __init__(self, table, whereclause, **params):
         self.table = table
         self.whereclause = whereclause
-        self.engine = self.table.engine
+        self._engine = self.table.engine
 
     def accept_visitor(self, visitor):
         if self.whereclause is not None: