-# sql.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
import string, re, random
types = __import__('types')
-__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
def desc(column):
"""returns a descending ORDER BY clause element, e.g.:
""" returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """
return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN')
between = between_
+
+def case(whens, value=None, else_=None):
+ """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses;
+ optional [value] for simple case statements, and [else_] for case defaults """
+ whenlist = [CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
+ if else_:
+ whenlist.append(CompoundClause(None, 'ELSE', else_))
+ cc = CalculatedClause(None, 'CASE', value, *whenlist + ['END'])
+ for c in cc.clauses:
+ c.parens = False
+ return cc
def cast(clause, totype, **kwargs):
""" returns CAST function CAST(clause AS totype)
def visit_join(self, join):pass
def visit_null(self, null):pass
def visit_clauselist(self, list):pass
+ def visit_calculatedclause(self, calcclause):pass
def visit_function(self, func):pass
def visit_label(self, label):pass
def visit_typeclause(self, typeclause):pass
return self.operator == other.operator
else:
return False
+
+class CalculatedClause(ClauseList, ColumnElement):
+ """ describes a calculated SQL expression that has a type, like CASE. extends ColumnElement to
+ provide column-level comparison operators. """
+ def __init__(self, name, *clauses, **kwargs):
+ self.name = name
+ self.type = sqltypes.to_instance(kwargs.get('type', None))
+ self._engine = kwargs.get('engine', None)
+ ClauseList.__init__(self, *clauses)
+ key = property(lambda self:self.name or "_calc_")
+ def _process_from_dict(self, data, asfrom):
+ super(CalculatedClause, self)._process_from_dict(data, asfrom)
+ # this helps a Select object get the engine from us
+ data.setdefault(self, self)
+ def copy_container(self):
+ clauses = [clause.copy_container() for clause in self.clauses]
+ return CalculatedClause(type=self.type, engine=self._engine, *clauses)
+ def accept_visitor(self, visitor):
+ for c in self.clauses:
+ c.accept_visitor(visitor)
+ visitor.visit_calculatedclause(self)
+ def _bind_param(self, obj):
+ return BindParamClause(self.name, obj, type=self.type)
+ def select(self):
+ return select([self])
+ def scalar(self):
+ return select([self]).scalar()
+ def execute(self):
+ return select([self]).execute()
+ def _compare_type(self, obj):
+ return self.type
+
-class Function(ClauseList, ColumnElement):
- """describes a SQL function. extends ClauseList to provide comparison operators."""
+class Function(CalculatedClause):
+ """describes a SQL function. extends CalculatedClause turn the "clauselist" into function
+ arguments, also adds a "packagenames" argument"""
def __init__(self, name, *clauses, **kwargs):
self.name = name
self.type = sqltypes.to_instance(kwargs.get('type', None))
else:
clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
self.clauses.append(clause)
- def _process_from_dict(self, data, asfrom):
- super(Function, self)._process_from_dict(data, asfrom)
- # this helps a Select object get the engine from us
- data.setdefault(self, self)
def copy_container(self):
clauses = [clause.copy_container() for clause in self.clauses]
return Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
for c in self.clauses:
c.accept_visitor(visitor)
visitor.visit_function(self)
- def _bind_param(self, obj):
- return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
- def select(self):
- return select([self])
- def scalar(self):
- return select([self]).scalar()
- def execute(self):
- return select([self]).execute()
- def _compare_type(self, obj):
- return self.type
+
class FunctionGenerator(object):
"""generates Function objects based on getattr calls"""
--- /dev/null
+import sys
+import testbase
+from sqlalchemy import *
+
+
+class CaseTest(testbase.PersistTest):
+
+ def setUpAll(self):
+ global info_table
+ info_table = Table('infos', testbase.db,
+ Column('pk', Integer, primary_key=True),
+ Column('info', String))
+
+ info_table.create()
+
+ info_table.insert().execute(
+ {'pk':1, 'info':'pk_1_data'},
+ {'pk':2, 'info':'pk_2_data'},
+ {'pk':3, 'info':'pk_3_data'},
+ {'pk':4, 'info':'pk_4_data'},
+ {'pk':5, 'info':'pk_5_data'})
+ def tearDownAll(self):
+ info_table.drop()
+
+ def testcase(self):
+ inner = select([case([[info_table.c.pk < 3, literal('lessthan3', type=String)],
+ [info_table.c.pk >= 3, literal('gt3', type=String)]]).label('x'),
+ info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner')
+
+ inner_result = inner.execute().fetchall()
+
+ # Outputs:
+ # lessthan3 1 pk_1_data
+ # lessthan3 2 pk_2_data
+ # gt3 3 pk_3_data
+ # gt3 4 pk_4_data
+ # gt3 5 pk_5_data
+ assert inner_result == [
+ ('lessthan3', 1, 'pk_1_data'),
+ ('lessthan3', 2, 'pk_2_data'),
+ ('gt3', 3, 'pk_3_data'),
+ ('gt3', 4, 'pk_4_data'),
+ ('gt3', 5, 'pk_5_data'),
+ ]
+
+ outer = select([inner])
+
+ outer_result = outer.execute().fetchall()
+
+ assert outer_result == [
+ ('lessthan3', 1, 'pk_1_data'),
+ ('lessthan3', 2, 'pk_2_data'),
+ ('gt3', 3, 'pk_3_data'),
+ ('gt3', 4, 'pk_4_data'),
+ ('gt3', 5, 'pk_5_data'),
+ ]
+
+if __name__ == "__main__":
+ testbase.main()