From 20e7f815b115d0dc5b864cdeda2f4220e34f417a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 4 Dec 2005 02:15:06 +0000 Subject: [PATCH] math operators &|~ boolean operators added 'literal' keyword working on column clauses being more flexible --- lib/sqlalchemy/ansisql.py | 22 +++++++++------ lib/sqlalchemy/schema.py | 8 ++++++ lib/sqlalchemy/sql.py | 59 +++++++++++++++++++++++++++++++++------ test/mapper.py | 2 +- test/objectstore.py | 2 +- test/rundocs.py | 3 +- 6 files changed, 75 insertions(+), 21 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 5490819476..c5de1adfd4 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -187,7 +187,8 @@ class ANSICompiler(sql.Compiled): self.strings[binary] = result def visit_bindparam(self, bindparam): - self.binds[bindparam.shortname] = bindparam + if bindparam.shortname != bindparam.key: + self.binds[bindparam.shortname] = bindparam count = 1 key = bindparam.key @@ -210,13 +211,18 @@ class ANSICompiler(sql.Compiled): inner_columns = [] for c in select._raw_columns: - for co in c.columns: - co.accept_visitor(self) - inner_columns.append(co) - if select.use_labels: - self.typemap.setdefault(co.label, co.type) - else: - self.typemap.setdefault(co.key, co.type) + # TODO: hackish. try to get a more polymorphic approach. + if hasattr(c, 'columns'): + for co in c.columns: + co.accept_visitor(self) + inner_columns.append(co) + if select.use_labels: + self.typemap.setdefault(co.label, co.type) + else: + self.typemap.setdefault(co.key, co.type) + else: + c.accept_visitor(self) + inner_columns.append(c) if select.use_labels: collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ') diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index d5dd69dda3..bebf38efed 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -221,6 +221,14 @@ class Column(SchemaItem): def __ne__(self, other): return self._impl.__ne__(other) def __gt__(self, other): return self._impl.__gt__(other) def __ge__(self, other): return self._impl.__ge__(other) + def __add__(self, other): return self._impl.__add__(other) + def __sub__(self, other): return self._impl.__sub__(other) + def __mul__(self, other): return self._impl.__mul__(other) + def __and__(self, other): return self._impl.__and__(other) + def __or__(self, other): return self._impl.__or__(other) + def __div__(self, other): return self._impl.__div__(other) + def __truediv__(self, other): return self._impl.__truediv__(other) + def __invert__(self, other): return self._impl.__invert__(other) def __str__(self): return self._impl.__str__() class ForeignKey(SchemaItem): diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index a030bf4396..4f7090cf7b 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -23,7 +23,7 @@ import sqlalchemy.util as util import sqlalchemy.types as types import string, re -__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence', 'exists'] +__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'sequence', 'exists'] def desc(column): """returns a descending ORDER BY clause element, e.g.: @@ -143,6 +143,9 @@ def alias(*args, **params): def subquery(alias, *args, **params): return Alias(Select(*args, **params), alias) +def literal(value, type=None): + return BindParamClause('literal', value, type=type) + def bindparam(key, value = None, type=None): if isinstance(key, schema.Column): return BindParamClause(key.name, value, type=key.type) @@ -323,6 +326,13 @@ class ClauseElement(object): sequences, rowcounts, etc.""" return self.execute(*multiparams, **params).fetchone()[0] + def __and__(self, other): + return and_(self, other) + def __or__(self, other): + return or_(self, other) + def __invert__(self): + return not_(self) + class CompareMixin(object): def __lt__(self, other): return self._compare('<', other) @@ -364,6 +374,28 @@ class CompareMixin(object): def endswith(self, other): return self._compare('LIKE', "%" + str(other)) + # and here come the math operators: + def __add__(self, other): + return self._compare('+', other) + def __sub__(self, other): + return self._compare('-', other) + def __mul__(self, other): + return self._compare('*', other) + def __div__(self, other): + return self._compare('/', other) + def __truediv__(self, other): + return self._compare('/', other) + def _compare(self, operator, obj): + if _is_literal(obj): + if obj is None: + if operator != '=': + raise "Only '=' operator can be used with NULL" + return BinaryClause(self, null(), 'IS') + else: + obj = BindParamClause('literal', obj, shortname=None, type=self.type) + + return BinaryClause(self, obj, operator) + class ColumnClause(ClauseElement, CompareMixin): """represents a textual column clause in a SQL statement.""" @@ -427,7 +459,7 @@ class FromClause(ClauseElement): def accept_visitor(self, visitor): visitor.visit_fromclause(self) -class BindParamClause(ClauseElement): +class BindParamClause(ClauseElement, CompareMixin): def __init__(self, key, value, shortname = None, type = None): self.key = key self.value = value @@ -532,12 +564,19 @@ class CompoundClause(ClauseList): class Function(ClauseList, CompareMixin): """describes a SQL function. extends ClauseList to provide comparison operators.""" def __init__(self, name, *clauses, **kwargs): - ClauseList.__init__(self, parens=True, *clauses) self.name = name self.type = kwargs.get('type', None) self.label = kwargs.get('label', None) + ClauseList.__init__(self, parens=True, *clauses) columns = property(lambda self: [self]) key = property(lambda self:self.label or self.name) + def append(self, clause): + if _is_literal(clause): + if clause is None: + clause = null() + else: + clause = BindParamClause(self.name, clause, shortname=self.name, type=None) + self.clauses.append(clause) def copy_container(self): return self def accept_visitor(self, visitor): @@ -941,13 +980,15 @@ class Select(Selectable, TailClauseMixin): if self.rowid_column is None and hasattr(f, 'rowid_column'): self.rowid_column = f.rowid_column._make_proxy(self) column._process_from_dict(self._froms, False) - - for co in column.columns: - if self.use_labels: - co._make_proxy(self, name = co.label) - else: - co._make_proxy(self) + # TODO: dont use hasattr here, get a general way to locate + # selectable columns off stuff working completely (i.e. Selectable) + if hasattr(column, 'columns'): + for co in column.columns: + if self.use_labels: + co._make_proxy(self, name = co.label) + else: + co._make_proxy(self) def get_col_by_original(self, column): if self.use_labels: diff --git a/test/mapper.py b/test/mapper.py index deaa02c6cd..54284f0c0d 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -43,7 +43,7 @@ class MapperTest(MapperSuperTest): def testmultitable(self): usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - m = mapper(User, usersaddresses, primarytable = users, primary_keys=[users.c.user_id]) + m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id]) l = m.select() self.assert_result(l, User, {'user_id' : 7}, {'user_id' : 8}) diff --git a/test/objectstore.py b/test/objectstore.py index f17b49c123..0942e41eec 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -604,7 +604,7 @@ class SaveTest(AssertMixin): m = mapper(Item, items, properties = dict( keywords = relation(IKAssociation, itemkeywords, lazy = False, properties = dict( keyword = relation(Keyword, keywords, lazy = False, uselist = False) - ), primary_keys = [itemkeywords.c.item_id, itemkeywords.c.keyword_id]) + ), primary_key = [itemkeywords.c.item_id, itemkeywords.c.keyword_id]) )) data = [Item, diff --git a/test/rundocs.py b/test/rundocs.py index b98ff393b2..66a4b9a60b 100644 --- a/test/rundocs.py +++ b/test/rundocs.py @@ -208,7 +208,7 @@ class KeywordAssociation(object):pass # lazy loading for that. m = mapper(Article, articles, properties=dict( keywords = relation(KeywordAssociation, itemkeywords, lazy = False, - primary_keys=[itemkeywords.c.article_id, itemkeywords.c.keyword_id], + primary_key=[itemkeywords.c.article_id, itemkeywords.c.keyword_id], properties=dict( keyword = relation(Keyword, keywords, lazy = False), user = relation(User, users, lazy = True) @@ -236,4 +236,3 @@ for a in alist: if k.keyword.name == 'jacks_stories': print k.user.user_name - -- 2.47.2