]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added group_by, having to select. added func.foo(a, b) keyword to express functions...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2005 06:43:23 +0000 (06:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2005 06:43:23 +0000 (06:43 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py

index 55550bfa88650415c7214732a9e5ae413b37f74c..4d65a70ceafaf125b16948fab1ceacf71bdde6b8 100644 (file)
@@ -133,7 +133,10 @@ class ANSICompiler(sql.Compiled):
             self.strings[column] = "%s.%s" % (column.table.name, column.name)
 
     def visit_columnclause(self, column):
-        self.strings[column] = "%s.%s" % (column.table.name, column.text)
+        if column.table is not None and column.table.name is not None:
+            self.strings[column] = "%s.%s" % (column.table.name, column.text)
+        else:
+            self.strings[column] = column.text
 
     def visit_fromclause(self, fromclause):
         self.froms[fromclause] = fromclause.from_name
@@ -143,7 +146,8 @@ class ANSICompiler(sql.Compiled):
             self.strings[textclause] = "(" + textclause.text + ")"
         else:
             self.strings[textclause] = textclause.text
-
+        self.froms[textclause] = textclause.text
+        
     def visit_null(self, null):
         self.strings[null] = 'NULL'
        
@@ -163,6 +167,9 @@ class ANSICompiler(sql.Compiled):
             self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")"
         else:
             self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
+
+    def visit_function(self, func):
+        self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
         
     def visit_binary(self, binary):
         result = self.get_str(binary.left)
@@ -198,6 +205,7 @@ class ANSICompiler(sql.Compiled):
 
         for c in select._raw_columns:
             for co in c.columns:
+                co.accept_visitor(self)
                 inner_columns.append(co)
                 if select.use_labels:
                     self.typemap.setdefault(co.label, co.type)
@@ -205,9 +213,9 @@ class ANSICompiler(sql.Compiled):
                     self.typemap.setdefault(co.key, co.type)
                 
         if select.use_labels:
-            collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ')
+            collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ')
         else:
-            collist = string.join([c.fullname for c in inner_columns], ', ')
+            collist = string.join([self.get_str(c) for c in inner_columns], ', ')
 
         text = "SELECT "
         if select.distinct:
@@ -240,6 +248,11 @@ class ANSICompiler(sql.Compiled):
         for tup in select._clauses:
             text += " " + tup[0] + " " + self.get_str(tup[1])
 
+        if select.having is not None:
+            t = self.get_str(select.having)
+            if t:
+                text += " \nHAVING " + t
+                
         if getattr(select, 'issubquery', False):
             self.strings[select] = "(" + text + ")"
         else:
index 5b68ff1bddd169a03a6d966c5763b7811c2d74c8..83ac01a23e4c7f518dec94b03070434739a96255 100644 (file)
@@ -23,7 +23,7 @@ import sqlalchemy.util as util
 import sqlalchemy.types as types
 import string, re
 
-__ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence']
+__ALL__ = ['text', 'column', 'func', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence']
 
 def desc(column):
     """returns a descending ORDER BY clause element, e.g.:
@@ -122,7 +122,11 @@ def or_(*clauses):
 def not_(clause):
     clause.parens=True
     return BinaryClause(TextClause("NOT"), clause, None)
-    
+
+            
+def column(table, text):
+    return ColumnClause(text, table)
+        
 def exists(*args, **params):
     s = select(*args, **params)
     return BinaryClause(TextClause("EXISTS"), s, None)
@@ -154,6 +158,13 @@ def null():
 def sequence():
     return Sequence()
 
+class FunctionGateway(object):
+    """returns a callable based on an attribute name, which then returns a Function 
+    object with that name."""
+    def __getattr__(self, name):
+        return lambda *c, **kwargs: Function(name, *c, **kwargs)
+func = FunctionGateway()
+
 def _compound_clause(keyword, *clauses):
     return CompoundClause(keyword, *clauses)
 
@@ -187,6 +198,7 @@ class ClauseVisitor(schema.SchemaVisitor):
     def visit_join(self, join):pass
     def visit_null(self, null):pass
     def visit_clauselist(self, list):pass
+    def visit_function(self, func):pass
     
 class Compiled(ClauseVisitor):
     """represents a compiled SQL expression.  the __str__ method of the Compiled object
@@ -345,7 +357,7 @@ class CompareMixin(object):
 class ColumnClause(ClauseElement, CompareMixin):
     """represents a textual column clause in a SQL statement."""
 
-    def __init__(self, text, selectable):
+    def __init__(self, text, selectable=None):
         self.text = text
         self.table = selectable
         self._impl = ColumnImpl(self)
@@ -355,13 +367,15 @@ class ColumnClause(ClauseElement, CompareMixin):
     name = property(lambda self:self.text)
     key = property(lambda self:self.text)
     label = property(lambda self:self.text)
-    fullname = property(lambda self:self.text)
 
     def accept_visitor(self, visitor): 
         visitor.visit_columnclause(self)
 
     def hash_key(self):
-        return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key())
+        if self.table is not None:
+            return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key())
+        else:
+            return "ColumnClause(%s)" % self.text
 
     def _get_from_objects(self):
         return []
@@ -431,6 +445,7 @@ class TextClause(ClauseElement):
         self.text = text
         self.parens = False
         self.engine = engine
+        self.id = id(self)
         if isliteral:
             if isinstance(text, int) or isinstance(text, long):
                 self.text = str(text)
@@ -452,53 +467,84 @@ class Null(ClauseElement):
     def hash_key(self):
         return "Null"
     
-class CompoundClause(ClauseElement):
-    """represents a list of clauses joined by an operator"""
-    def __init__(self, operator, *clauses, **kwargs):
-        self.operator = operator
+        
+class ClauseList(ClauseElement):
+    """describes a list of clauses.  by default, is comma-separated, 
+    such as a column listing."""
+    def __init__(self, *clauses, **kwargs):
         self.clauses = []
-        self.parens = False
         for c in clauses:
             if c is None: continue
             self.append(c)
-    
+        self.parens = kwargs.get('parens', False)
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
-        return CompoundClause(self.operator, *clauses)
-        
+        return ClauseList(parens=self.parens, *clauses)
     def append(self, clause):
         if _is_literal(clause):
             clause = TextClause(str(clause))
-        elif isinstance(clause, CompoundClause):
-            clause.parens = True
         self.clauses.append(clause)
+    def accept_visitor(self, visitor):
+        for c in self.clauses:
+            c.accept_visitor(visitor)
+        visitor.visit_clauselist(self)
+    def _get_from_objects(self):
+        return []
 
+class CompoundClause(ClauseList):
+    """represents a list of clauses joined by an operator, such as AND or OR.  
+    extends ClauseList to add the operator as well as a from_objects accessor to 
+    help determine FROM objects in a SELECT statement."""
+    def __init__(self, operator, *clauses, **kwargs):
+        ClauseList.__init__(self, *clauses, **kwargs)
+        self.operator = operator
+    def copy_container(self):
+        clauses = [clause.copy_container() for clause in self.clauses]
+        return CompoundClause(self.operator, *clauses)
+    def append(self, clause):
+        if isinstance(clause, CompoundClause):
+            clause.parens = True
+        ClauseList.append(self, clause)
     def accept_visitor(self, visitor):
         for c in self.clauses:
             c.accept_visitor(visitor)
         visitor.visit_compound(self)
-
     def _get_from_objects(self):
         f = []
         for c in self.clauses:
             f += c._get_from_objects()
         return f
-        
     def hash_key(self):
         return string.join([c.hash_key() for c in self.clauses], self.operator)
-        
-class ClauseList(ClauseElement):
-    def __init__(self, *clauses, **kwargs):
-        self.clauses = clauses
-        self.parens = kwargs.get('parens', False)
-        
+
+class Function(ClauseList, CompareMixin):
+    """describes a SQL function. extends ClauseList to provide comparison operators."""
+    def __init__(self, name, *clauses, **kwargs):
+        ClauseList.__init__(self, parens=True, *clauses)
+        self.name = name
+        self.type = kwargs.get('type', None)
+        self.label = kwargs.get('label', None)
+    columns = property(lambda self: [self])
+    key = property(lambda self:self.label or self.name)
+    def copy_container(self):
+        return self
     def accept_visitor(self, visitor):
         for c in self.clauses:
             c.accept_visitor(visitor)
-        visitor.visit_clauselist(self)
-    
-    def _get_from_objects(self):
-        return []
+        visitor.visit_function(self)
+    def _compare(self, operator, obj):
+        if _is_literal(obj):
+            if obj is None:
+                if operator != '=':
+                    raise "Only '=' operator can be used with NULL"
+                return BinaryClause(self, null(), 'IS')
+            else:
+                obj = BindParamClause(self.name, obj, shortname=self.name, type=self.type)
+
+        return BinaryClause(self, obj, operator)
+    def _make_proxy(self, selectable, name = None):
+        return self
+
         
 class BinaryClause(ClauseElement):
     """represents two clauses with an operator in between"""
@@ -654,10 +700,8 @@ class ColumnImpl(Selectable, CompareMixin):
         
         if column.table.name:
             self.label = column.table.name + "_" + self.column.name
-            self.fullname = column.table.name + "." + self.column.name
         else:
             self.label = self.column.name
-            self.fullname = self.column.name
 
     engine = property(lambda s: s.column.engine)
     
@@ -721,8 +765,8 @@ class TableImpl(Selectable):
     def alias(self, name):
         return Alias(self.table, name)
             
-    def select(self, whereclauses = None, **params):
-        return select([self.table], whereclauses, **params)
+    def select(self, whereclause = None, **params):
+        return select([self.table], whereclause, **params)
 
     def insert(self, values = None):
         return insert(self.table, values=values)
@@ -748,13 +792,14 @@ class TableImpl(Selectable):
 class Select(Selectable):
     """finally, represents a SELECT statement, with appendable clauses, as well as 
     the ability to execute itself and return a result set."""
-    def __init__(self, columns, whereclause = None, from_obj = [], group_by = None, order_by = None, use_labels = False, distinct=False, engine = None):
+    def __init__(self, columns, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None):
         self.columns = util.OrderedProperties()
         self._froms = util.OrderedDict()
         self.use_labels = use_labels
         self.id = "Select(%d)" % id(self)
         self.name = None
         self.whereclause = None
+        self.having = None
         self._engine = engine
         self.rowid_column = None
         
@@ -777,16 +822,17 @@ class Select(Selectable):
             
         if whereclause is not None:
             self.append_whereclause(whereclause)
-
+        if having is not None:
+            self.append_having(having)
+            
         for f in from_obj:
             self.append_from(f)
 
-        if group_by:
-            self.append_clause("GROUP_BY", group_by)
-
         if order_by:
             self.order_by(*order_by)
-
+        if group_by:
+            self.group_by(*group_by)
+            
     class CorrelatedVisitor(ClauseVisitor):
         """visits a clause, locates any Select clauses, and tells them that they should
         correlate their FROM list to that of their parent."""
@@ -820,17 +866,19 @@ class Select(Selectable):
                 co._make_proxy(self)
 
     def append_whereclause(self, whereclause):
-        if type(whereclause) == str:
-            whereclause = TextClause(whereclause)
-
-        whereclause.accept_visitor(self._wherecorrelator)
-        whereclause._process_from_dict(self._froms, False)
-        
-        if self.whereclause is not None:
-            self.whereclause = and_(self.whereclause, whereclause)
+        self._append_condition('whereclause', whereclause)
+    def append_having(self, having):
+        self._append_condition('having', having)
+    def _append_condition(self, attribute, condition):
+        if type(condition) == str:
+            condition = TextClause(condition)
+        condition.accept_visitor(self._wherecorrelator)
+        condition._process_from_dict(self._froms, False)
+        if getattr(self, attribute) is not None:
+            setattr(self, attribute, and_(getattr(self, attribute), condition))
         else:
-            self.whereclause = whereclause
-
+            setattr(self, attribute, condition)
+    
     def clear_from(self, id):
         self.append_from(FromClause(from_name = None, from_key = id))
         
@@ -844,7 +892,6 @@ class Select(Selectable):
     def append_clause(self, keyword, clause):
         if type(clause) == str:
             clause = TextClause(clause)
-        
         self._clauses.append((keyword, clause))
         
     def compile(self, engine = None, bindparams = None):
@@ -852,7 +899,6 @@ class Select(Selectable):
             engine = self.engine
         if engine is None:
             raise "no engine supplied, and no engine could be located within the clauses!"
-
         return engine.compile(self, bindparams)
 
     def _get_froms(self):
@@ -864,18 +910,25 @@ class Select(Selectable):
             f.accept_visitor(visitor)
         if self.whereclause is not None:
             self.whereclause.accept_visitor(visitor)
+        if self.having is not None:
+            self.having.accept_visitor(visitor)
         for tup in self._clauses:
             tup[1].accept_visitor(visitor)
             
         visitor.visit_select(self)
     
     def order_by(self, *clauses):
-        if not hasattr(self, 'order_by_clause'):
-            self.order_by_clause = ClauseList(*clauses)
-            self.append_clause("ORDER BY", self.order_by_clause)
+        self._append_clause('order_by_clause', "ORDER BY", *clauses)
+    def group_by(self, *clauses):
+        self._append_clause('group_by_clause', "GROUP BY", *clauses)
+    def _append_clause(self, attribute, prefix, *clauses):
+        if not hasattr(self, attribute):
+            l = ClauseList(*clauses)
+            setattr(self, attribute, l)
+            self.append_clause(prefix, l)
         else:
-            self.order_by_clause.clauses += clauses
-        
+            getattr(self, attribute).clauses  += clauses
+                
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)