]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
rick morrison's CASE statement + unit test
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 May 2006 23:47:07 +0000 (23:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 15 May 2006 23:47:07 +0000 (23:47 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/sql.py
test/alltests.py
test/case_statement.py [new file with mode: 0644]

index 602da58eef973414b34dbf605ee6821c71f90e68..df3f8fa59ae5ae54b96f246bdfa176f054ecdd4d 100644 (file)
@@ -224,7 +224,13 @@ class ANSICompiler(sql.Compiled):
 
     def apply_function_parens(self, func):
         return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
-        
+
+    def visit_calculatedclause(self, list):
+        if list.parens:
+            self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")"
+        else:
+            self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ')
+       
     def visit_function(self, func):
         if len(self.select_stack):
             self.typemap.setdefault(func.name, func.type)
index 5edfe6f1faa9ad69ff894867ba59b9a5549a0f53..fc0346b855f83cae4758553d0316a2046cf31a75 100644 (file)
@@ -1,4 +1,3 @@
-# sql.py
 # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
@@ -13,7 +12,7 @@ from exceptions import *
 import string, re, random
 types = __import__('types')
 
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
 
 def desc(column):
     """returns a descending ORDER BY clause element, e.g.:
@@ -132,6 +131,17 @@ def between_(ctest, cleft, cright):
     """ returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """
     return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN')
 between = between_
+
+def case(whens, value=None, else_=None):
+    """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses;
+        optional [value] for simple case statements, and [else_] for case defaults """
+    whenlist = [CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
+    if else_:
+        whenlist.append(CompoundClause(None, 'ELSE', else_))
+    cc = CalculatedClause(None, 'CASE', value, *whenlist + ['END'])
+    for c in cc.clauses:
+        c.parens = False
+    return cc
    
 def cast(clause, totype, **kwargs):
     """ returns CAST function CAST(clause AS totype) 
@@ -295,6 +305,7 @@ class ClauseVisitor(object):
     def visit_join(self, join):pass
     def visit_null(self, null):pass
     def visit_clauselist(self, list):pass
+    def visit_calculatedclause(self, calcclause):pass
     def visit_function(self, func):pass
     def visit_label(self, label):pass
     def visit_typeclause(self, typeclause):pass
@@ -831,9 +842,42 @@ class CompoundClause(ClauseList):
             return self.operator == other.operator
         else:
             return False
+
+class CalculatedClause(ClauseList, ColumnElement):
+    """ describes a calculated SQL expression that has a type, like CASE.  extends ColumnElement to
+    provide column-level comparison operators.  """
+    def __init__(self, name, *clauses, **kwargs):
+        self.name = name
+        self.type = sqltypes.to_instance(kwargs.get('type', None))
+        self._engine = kwargs.get('engine', None)
+        ClauseList.__init__(self, *clauses)
+    key = property(lambda self:self.name or "_calc_")
+    def _process_from_dict(self, data, asfrom):
+        super(CalculatedClause, self)._process_from_dict(data, asfrom)
+        # this helps a Select object get the engine from us
+        data.setdefault(self, self)
+    def copy_container(self):
+        clauses = [clause.copy_container() for clause in self.clauses]
+        return CalculatedClause(type=self.type, engine=self._engine, *clauses)
+    def accept_visitor(self, visitor):
+        for c in self.clauses:
+            c.accept_visitor(visitor)
+        visitor.visit_calculatedclause(self)
+    def _bind_param(self, obj):
+        return BindParamClause(self.name, obj, type=self.type)
+    def select(self):
+        return select([self])
+    def scalar(self):
+        return select([self]).scalar()
+    def execute(self):
+        return select([self]).execute()
+    def _compare_type(self, obj):
+        return self.type
+
                 
-class Function(ClauseList, ColumnElement):
-    """describes a SQL function. extends ClauseList to provide comparison operators."""
+class Function(CalculatedClause):
+    """describes a SQL function. extends CalculatedClause turn the "clauselist" into function
+    arguments, also adds a "packagenames" argument"""
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type', None))
@@ -848,10 +892,6 @@ class Function(ClauseList, ColumnElement):
             else:
                 clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
         self.clauses.append(clause)
-    def _process_from_dict(self, data, asfrom):
-        super(Function, self)._process_from_dict(data, asfrom)
-        # this helps a Select object get the engine from us
-        data.setdefault(self, self)
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
         return Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
@@ -859,16 +899,7 @@ class Function(ClauseList, ColumnElement):
         for c in self.clauses:
             c.accept_visitor(visitor)
         visitor.visit_function(self)
-    def _bind_param(self, obj):
-        return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
-    def select(self):
-        return select([self])
-    def scalar(self):
-        return select([self]).scalar()
-    def execute(self):
-        return select([self]).execute()
-    def _compare_type(self, obj):
-        return self.type
+
 
 class FunctionGenerator(object):
     """generates Function objects based on getattr calls"""
index 4e9c73c2c22bb218bc7cb74b529d2a4950024746..3595edd7ed001a3a7e8bbf10fb9c1ddc1518eda8 100644 (file)
@@ -24,6 +24,7 @@ def suite():
         # SQL syntax
         'select',
         'selectable',
+        'case_statement', 
         
         # assorted round-trip tests
         'query',
diff --git a/test/case_statement.py b/test/case_statement.py
new file mode 100644 (file)
index 0000000..fc0e919
--- /dev/null
@@ -0,0 +1,59 @@
+import sys
+import testbase
+from sqlalchemy import *
+
+
+class CaseTest(testbase.PersistTest):
+
+    def setUpAll(self):
+        global info_table
+        info_table = Table('infos', testbase.db,
+               Column('pk', Integer, primary_key=True),
+               Column('info', String))
+
+        info_table.create()
+
+        info_table.insert().execute(
+               {'pk':1, 'info':'pk_1_data'},
+               {'pk':2, 'info':'pk_2_data'},
+               {'pk':3, 'info':'pk_3_data'},
+               {'pk':4, 'info':'pk_4_data'},
+           {'pk':5, 'info':'pk_5_data'})
+    def tearDownAll(self):
+        info_table.drop()
+    
+    def testcase(self):
+        inner = select([case([[info_table.c.pk < 3, literal('lessthan3', type=String)],
+               [info_table.c.pk >= 3, literal('gt3', type=String)]]).label('x'),
+               info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner')
+
+        inner_result = inner.execute().fetchall()
+
+        # Outputs:
+        # lessthan3 1 pk_1_data
+        # lessthan3 2 pk_2_data
+        # gt3 3 pk_3_data
+        # gt3 4 pk_4_data
+        # gt3 5 pk_5_data
+        assert inner_result == [
+            ('lessthan3', 1, 'pk_1_data'),
+            ('lessthan3', 2, 'pk_2_data'),
+            ('gt3', 3, 'pk_3_data'),
+            ('gt3', 4, 'pk_4_data'),
+            ('gt3', 5, 'pk_5_data'),
+        ]
+
+        outer = select([inner])
+
+        outer_result = outer.execute().fetchall()
+
+        assert outer_result == [
+            ('lessthan3', 1, 'pk_1_data'),
+            ('lessthan3', 2, 'pk_2_data'),
+            ('gt3', 3, 'pk_3_data'),
+            ('gt3', 4, 'pk_4_data'),
+            ('gt3', 5, 'pk_5_data'),
+        ]
+
+if __name__ == "__main__":
+    testbase.main()