From: Mike Bayer Date: Tue, 29 Nov 2005 06:43:23 +0000 (+0000) Subject: added group_by, having to select. added func.foo(a, b) keyword to express functions... X-Git-Tag: rel_0_1_0~287 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=92dc0d0dbda52a94911f858429a41b09b728ad52;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added group_by, having to select. added func.foo(a, b) keyword to express functions within column lists and criterion lists --- diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 55550bfa88..4d65a70cea 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -133,7 +133,10 @@ class ANSICompiler(sql.Compiled): self.strings[column] = "%s.%s" % (column.table.name, column.name) def visit_columnclause(self, column): - self.strings[column] = "%s.%s" % (column.table.name, column.text) + if column.table is not None and column.table.name is not None: + self.strings[column] = "%s.%s" % (column.table.name, column.text) + else: + self.strings[column] = column.text def visit_fromclause(self, fromclause): self.froms[fromclause] = fromclause.from_name @@ -143,7 +146,8 @@ class ANSICompiler(sql.Compiled): self.strings[textclause] = "(" + textclause.text + ")" else: self.strings[textclause] = textclause.text - + self.froms[textclause] = textclause.text + def visit_null(self, null): self.strings[null] = 'NULL' @@ -163,6 +167,9 @@ class ANSICompiler(sql.Compiled): self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ', ') + ")" else: self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') + + def visit_function(self, func): + self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" def visit_binary(self, binary): result = self.get_str(binary.left) @@ -198,6 +205,7 @@ class ANSICompiler(sql.Compiled): for c in select._raw_columns: for co in c.columns: + co.accept_visitor(self) inner_columns.append(co) if select.use_labels: self.typemap.setdefault(co.label, co.type) @@ -205,9 +213,9 @@ class ANSICompiler(sql.Compiled): self.typemap.setdefault(co.key, co.type) if select.use_labels: - collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ') + collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ') else: - collist = string.join([c.fullname for c in inner_columns], ', ') + collist = string.join([self.get_str(c) for c in inner_columns], ', ') text = "SELECT " if select.distinct: @@ -240,6 +248,11 @@ class ANSICompiler(sql.Compiled): for tup in select._clauses: text += " " + tup[0] + " " + self.get_str(tup[1]) + if select.having is not None: + t = self.get_str(select.having) + if t: + text += " \nHAVING " + t + if getattr(select, 'issubquery', False): self.strings[select] = "(" + text + ")" else: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5b68ff1bdd..83ac01a23e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -23,7 +23,7 @@ import sqlalchemy.util as util import sqlalchemy.types as types import string, re -__ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence'] +__ALL__ = ['text', 'column', 'func', 'select', 'join', 'and_', 'or_', 'not_', 'union', 'unionall', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence'] def desc(column): """returns a descending ORDER BY clause element, e.g.: @@ -122,7 +122,11 @@ def or_(*clauses): def not_(clause): clause.parens=True return BinaryClause(TextClause("NOT"), clause, None) - + + +def column(table, text): + return ColumnClause(text, table) + def exists(*args, **params): s = select(*args, **params) return BinaryClause(TextClause("EXISTS"), s, None) @@ -154,6 +158,13 @@ def null(): def sequence(): return Sequence() +class FunctionGateway(object): + """returns a callable based on an attribute name, which then returns a Function + object with that name.""" + def __getattr__(self, name): + return lambda *c, **kwargs: Function(name, *c, **kwargs) +func = FunctionGateway() + def _compound_clause(keyword, *clauses): return CompoundClause(keyword, *clauses) @@ -187,6 +198,7 @@ class ClauseVisitor(schema.SchemaVisitor): def visit_join(self, join):pass def visit_null(self, null):pass def visit_clauselist(self, list):pass + def visit_function(self, func):pass class Compiled(ClauseVisitor): """represents a compiled SQL expression. the __str__ method of the Compiled object @@ -345,7 +357,7 @@ class CompareMixin(object): class ColumnClause(ClauseElement, CompareMixin): """represents a textual column clause in a SQL statement.""" - def __init__(self, text, selectable): + def __init__(self, text, selectable=None): self.text = text self.table = selectable self._impl = ColumnImpl(self) @@ -355,13 +367,15 @@ class ColumnClause(ClauseElement, CompareMixin): name = property(lambda self:self.text) key = property(lambda self:self.text) label = property(lambda self:self.text) - fullname = property(lambda self:self.text) def accept_visitor(self, visitor): visitor.visit_columnclause(self) def hash_key(self): - return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key()) + if self.table is not None: + return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key()) + else: + return "ColumnClause(%s)" % self.text def _get_from_objects(self): return [] @@ -431,6 +445,7 @@ class TextClause(ClauseElement): self.text = text self.parens = False self.engine = engine + self.id = id(self) if isliteral: if isinstance(text, int) or isinstance(text, long): self.text = str(text) @@ -452,53 +467,84 @@ class Null(ClauseElement): def hash_key(self): return "Null" -class CompoundClause(ClauseElement): - """represents a list of clauses joined by an operator""" - def __init__(self, operator, *clauses, **kwargs): - self.operator = operator + +class ClauseList(ClauseElement): + """describes a list of clauses. by default, is comma-separated, + such as a column listing.""" + def __init__(self, *clauses, **kwargs): self.clauses = [] - self.parens = False for c in clauses: if c is None: continue self.append(c) - + self.parens = kwargs.get('parens', False) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] - return CompoundClause(self.operator, *clauses) - + return ClauseList(parens=self.parens, *clauses) def append(self, clause): if _is_literal(clause): clause = TextClause(str(clause)) - elif isinstance(clause, CompoundClause): - clause.parens = True self.clauses.append(clause) + def accept_visitor(self, visitor): + for c in self.clauses: + c.accept_visitor(visitor) + visitor.visit_clauselist(self) + def _get_from_objects(self): + return [] +class CompoundClause(ClauseList): + """represents a list of clauses joined by an operator, such as AND or OR. + extends ClauseList to add the operator as well as a from_objects accessor to + help determine FROM objects in a SELECT statement.""" + def __init__(self, operator, *clauses, **kwargs): + ClauseList.__init__(self, *clauses, **kwargs) + self.operator = operator + def copy_container(self): + clauses = [clause.copy_container() for clause in self.clauses] + return CompoundClause(self.operator, *clauses) + def append(self, clause): + if isinstance(clause, CompoundClause): + clause.parens = True + ClauseList.append(self, clause) def accept_visitor(self, visitor): for c in self.clauses: c.accept_visitor(visitor) visitor.visit_compound(self) - def _get_from_objects(self): f = [] for c in self.clauses: f += c._get_from_objects() return f - def hash_key(self): return string.join([c.hash_key() for c in self.clauses], self.operator) - -class ClauseList(ClauseElement): - def __init__(self, *clauses, **kwargs): - self.clauses = clauses - self.parens = kwargs.get('parens', False) - + +class Function(ClauseList, CompareMixin): + """describes a SQL function. extends ClauseList to provide comparison operators.""" + def __init__(self, name, *clauses, **kwargs): + ClauseList.__init__(self, parens=True, *clauses) + self.name = name + self.type = kwargs.get('type', None) + self.label = kwargs.get('label', None) + columns = property(lambda self: [self]) + key = property(lambda self:self.label or self.name) + def copy_container(self): + return self def accept_visitor(self, visitor): for c in self.clauses: c.accept_visitor(visitor) - visitor.visit_clauselist(self) - - def _get_from_objects(self): - return [] + visitor.visit_function(self) + def _compare(self, operator, obj): + if _is_literal(obj): + if obj is None: + if operator != '=': + raise "Only '=' operator can be used with NULL" + return BinaryClause(self, null(), 'IS') + else: + obj = BindParamClause(self.name, obj, shortname=self.name, type=self.type) + + return BinaryClause(self, obj, operator) + def _make_proxy(self, selectable, name = None): + return self + class BinaryClause(ClauseElement): """represents two clauses with an operator in between""" @@ -654,10 +700,8 @@ class ColumnImpl(Selectable, CompareMixin): if column.table.name: self.label = column.table.name + "_" + self.column.name - self.fullname = column.table.name + "." + self.column.name else: self.label = self.column.name - self.fullname = self.column.name engine = property(lambda s: s.column.engine) @@ -721,8 +765,8 @@ class TableImpl(Selectable): def alias(self, name): return Alias(self.table, name) - def select(self, whereclauses = None, **params): - return select([self.table], whereclauses, **params) + def select(self, whereclause = None, **params): + return select([self.table], whereclause, **params) def insert(self, values = None): return insert(self.table, values=values) @@ -748,13 +792,14 @@ class TableImpl(Selectable): class Select(Selectable): """finally, represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" - def __init__(self, columns, whereclause = None, from_obj = [], group_by = None, order_by = None, use_labels = False, distinct=False, engine = None): + def __init__(self, columns, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None): self.columns = util.OrderedProperties() self._froms = util.OrderedDict() self.use_labels = use_labels self.id = "Select(%d)" % id(self) self.name = None self.whereclause = None + self.having = None self._engine = engine self.rowid_column = None @@ -777,16 +822,17 @@ class Select(Selectable): if whereclause is not None: self.append_whereclause(whereclause) - + if having is not None: + self.append_having(having) + for f in from_obj: self.append_from(f) - if group_by: - self.append_clause("GROUP_BY", group_by) - if order_by: self.order_by(*order_by) - + if group_by: + self.group_by(*group_by) + class CorrelatedVisitor(ClauseVisitor): """visits a clause, locates any Select clauses, and tells them that they should correlate their FROM list to that of their parent.""" @@ -820,17 +866,19 @@ class Select(Selectable): co._make_proxy(self) def append_whereclause(self, whereclause): - if type(whereclause) == str: - whereclause = TextClause(whereclause) - - whereclause.accept_visitor(self._wherecorrelator) - whereclause._process_from_dict(self._froms, False) - - if self.whereclause is not None: - self.whereclause = and_(self.whereclause, whereclause) + self._append_condition('whereclause', whereclause) + def append_having(self, having): + self._append_condition('having', having) + def _append_condition(self, attribute, condition): + if type(condition) == str: + condition = TextClause(condition) + condition.accept_visitor(self._wherecorrelator) + condition._process_from_dict(self._froms, False) + if getattr(self, attribute) is not None: + setattr(self, attribute, and_(getattr(self, attribute), condition)) else: - self.whereclause = whereclause - + setattr(self, attribute, condition) + def clear_from(self, id): self.append_from(FromClause(from_name = None, from_key = id)) @@ -844,7 +892,6 @@ class Select(Selectable): def append_clause(self, keyword, clause): if type(clause) == str: clause = TextClause(clause) - self._clauses.append((keyword, clause)) def compile(self, engine = None, bindparams = None): @@ -852,7 +899,6 @@ class Select(Selectable): engine = self.engine if engine is None: raise "no engine supplied, and no engine could be located within the clauses!" - return engine.compile(self, bindparams) def _get_froms(self): @@ -864,18 +910,25 @@ class Select(Selectable): f.accept_visitor(visitor) if self.whereclause is not None: self.whereclause.accept_visitor(visitor) + if self.having is not None: + self.having.accept_visitor(visitor) for tup in self._clauses: tup[1].accept_visitor(visitor) visitor.visit_select(self) def order_by(self, *clauses): - if not hasattr(self, 'order_by_clause'): - self.order_by_clause = ClauseList(*clauses) - self.append_clause("ORDER BY", self.order_by_clause) + self._append_clause('order_by_clause', "ORDER BY", *clauses) + def group_by(self, *clauses): + self._append_clause('group_by_clause', "GROUP BY", *clauses) + def _append_clause(self, attribute, prefix, *clauses): + if not hasattr(self, attribute): + l = ClauseList(*clauses) + setattr(self, attribute, l) + self.append_clause(prefix, l) else: - self.order_by_clause.clauses += clauses - + getattr(self, attribute).clauses += clauses + def select(self, whereclauses = None, **params): return select([self], whereclauses, **params)