From 79555fb434660bc4b317b5d384120deaa2dc9b60 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 15 May 2006 23:47:07 +0000 Subject: [PATCH] rick morrison's CASE statement + unit test --- lib/sqlalchemy/ansisql.py | 8 ++++- lib/sqlalchemy/sql.py | 67 ++++++++++++++++++++++++++++----------- test/alltests.py | 1 + test/case_statement.py | 59 ++++++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 19 deletions(-) create mode 100644 test/case_statement.py diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 602da58eef..df3f8fa59a 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -224,7 +224,13 @@ class ANSICompiler(sql.Compiled): def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - + + def visit_calculatedclause(self, list): + if list.parens: + self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")" + else: + self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ') + def visit_function(self, func): if len(self.select_stack): self.typemap.setdefault(func.name, func.type) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5edfe6f1fa..fc0346b855 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1,4 +1,3 @@ -# sql.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -13,7 +12,7 @@ from exceptions import * import string, re, random types = __import__('types') -__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] +__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'between_', 'case', 'cast', 'union', 'union_all', 'null', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists'] def desc(column): """returns a descending ORDER BY clause element, e.g.: @@ -132,6 +131,17 @@ def between_(ctest, cleft, cright): """ returns BETWEEN predicate clause (clausetest BETWEEN clauseleft AND clauseright) """ return BooleanExpression(ctest, and_(cleft, cright), 'BETWEEN') between = between_ + +def case(whens, value=None, else_=None): + """ SQL CASE statement -- whens are a sequence of pairs to be translated into "when / then" clauses; + optional [value] for simple case statements, and [else_] for case defaults """ + whenlist = [CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens] + if else_: + whenlist.append(CompoundClause(None, 'ELSE', else_)) + cc = CalculatedClause(None, 'CASE', value, *whenlist + ['END']) + for c in cc.clauses: + c.parens = False + return cc def cast(clause, totype, **kwargs): """ returns CAST function CAST(clause AS totype) @@ -295,6 +305,7 @@ class ClauseVisitor(object): def visit_join(self, join):pass def visit_null(self, null):pass def visit_clauselist(self, list):pass + def visit_calculatedclause(self, calcclause):pass def visit_function(self, func):pass def visit_label(self, label):pass def visit_typeclause(self, typeclause):pass @@ -831,9 +842,42 @@ class CompoundClause(ClauseList): return self.operator == other.operator else: return False + +class CalculatedClause(ClauseList, ColumnElement): + """ describes a calculated SQL expression that has a type, like CASE. extends ColumnElement to + provide column-level comparison operators. """ + def __init__(self, name, *clauses, **kwargs): + self.name = name + self.type = sqltypes.to_instance(kwargs.get('type', None)) + self._engine = kwargs.get('engine', None) + ClauseList.__init__(self, *clauses) + key = property(lambda self:self.name or "_calc_") + def _process_from_dict(self, data, asfrom): + super(CalculatedClause, self)._process_from_dict(data, asfrom) + # this helps a Select object get the engine from us + data.setdefault(self, self) + def copy_container(self): + clauses = [clause.copy_container() for clause in self.clauses] + return CalculatedClause(type=self.type, engine=self._engine, *clauses) + def accept_visitor(self, visitor): + for c in self.clauses: + c.accept_visitor(visitor) + visitor.visit_calculatedclause(self) + def _bind_param(self, obj): + return BindParamClause(self.name, obj, type=self.type) + def select(self): + return select([self]) + def scalar(self): + return select([self]).scalar() + def execute(self): + return select([self]).execute() + def _compare_type(self, obj): + return self.type + -class Function(ClauseList, ColumnElement): - """describes a SQL function. extends ClauseList to provide comparison operators.""" +class Function(CalculatedClause): + """describes a SQL function. extends CalculatedClause turn the "clauselist" into function + arguments, also adds a "packagenames" argument""" def __init__(self, name, *clauses, **kwargs): self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) @@ -848,10 +892,6 @@ class Function(ClauseList, ColumnElement): else: clause = BindParamClause(self.name, clause, shortname=self.name, type=None) self.clauses.append(clause) - def _process_from_dict(self, data, asfrom): - super(Function, self)._process_from_dict(data, asfrom) - # this helps a Select object get the engine from us - data.setdefault(self, self) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) @@ -859,16 +899,7 @@ class Function(ClauseList, ColumnElement): for c in self.clauses: c.accept_visitor(visitor) visitor.visit_function(self) - def _bind_param(self, obj): - return BindParamClause(self.name, obj, shortname=self.name, type=self.type) - def select(self): - return select([self]) - def scalar(self): - return select([self]).scalar() - def execute(self): - return select([self]).execute() - def _compare_type(self, obj): - return self.type + class FunctionGenerator(object): """generates Function objects based on getattr calls""" diff --git a/test/alltests.py b/test/alltests.py index 4e9c73c2c2..3595edd7ed 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -24,6 +24,7 @@ def suite(): # SQL syntax 'select', 'selectable', + 'case_statement', # assorted round-trip tests 'query', diff --git a/test/case_statement.py b/test/case_statement.py new file mode 100644 index 0000000000..fc0e919a5b --- /dev/null +++ b/test/case_statement.py @@ -0,0 +1,59 @@ +import sys +import testbase +from sqlalchemy import * + + +class CaseTest(testbase.PersistTest): + + def setUpAll(self): + global info_table + info_table = Table('infos', testbase.db, + Column('pk', Integer, primary_key=True), + Column('info', String)) + + info_table.create() + + info_table.insert().execute( + {'pk':1, 'info':'pk_1_data'}, + {'pk':2, 'info':'pk_2_data'}, + {'pk':3, 'info':'pk_3_data'}, + {'pk':4, 'info':'pk_4_data'}, + {'pk':5, 'info':'pk_5_data'}) + def tearDownAll(self): + info_table.drop() + + def testcase(self): + inner = select([case([[info_table.c.pk < 3, literal('lessthan3', type=String)], + [info_table.c.pk >= 3, literal('gt3', type=String)]]).label('x'), + info_table.c.pk, info_table.c.info], from_obj=[info_table]).alias('q_inner') + + inner_result = inner.execute().fetchall() + + # Outputs: + # lessthan3 1 pk_1_data + # lessthan3 2 pk_2_data + # gt3 3 pk_3_data + # gt3 4 pk_4_data + # gt3 5 pk_5_data + assert inner_result == [ + ('lessthan3', 1, 'pk_1_data'), + ('lessthan3', 2, 'pk_2_data'), + ('gt3', 3, 'pk_3_data'), + ('gt3', 4, 'pk_4_data'), + ('gt3', 5, 'pk_5_data'), + ] + + outer = select([inner]) + + outer_result = outer.execute().fetchall() + + assert outer_result == [ + ('lessthan3', 1, 'pk_1_data'), + ('lessthan3', 2, 'pk_2_data'), + ('gt3', 3, 'pk_3_data'), + ('gt3', 4, 'pk_4_data'), + ('gt3', 5, 'pk_5_data'), + ] + +if __name__ == "__main__": + testbase.main() -- 2.47.2