]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
math operators
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Dec 2005 02:15:06 +0000 (02:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 4 Dec 2005 02:15:06 +0000 (02:15 +0000)
&|~ boolean operators
added 'literal' keyword
working on column clauses being more flexible

lib/sqlalchemy/ansisql.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/mapper.py
test/objectstore.py
test/rundocs.py

index 5490819476a8bd1af297b487cb1d9fcb04c87e48..c5de1adfd4455e2e615676792900ccf4ad4a3323 100644 (file)
@@ -187,7 +187,8 @@ class ANSICompiler(sql.Compiled):
         self.strings[binary] = result
 
     def visit_bindparam(self, bindparam):
-        self.binds[bindparam.shortname] = bindparam
+        if bindparam.shortname != bindparam.key:
+            self.binds[bindparam.shortname] = bindparam
         count = 1
         key = bindparam.key
 
@@ -210,13 +211,18 @@ class ANSICompiler(sql.Compiled):
         inner_columns = []
 
         for c in select._raw_columns:
-            for co in c.columns:
-                co.accept_visitor(self)
-                inner_columns.append(co)
-                if select.use_labels:
-                    self.typemap.setdefault(co.label, co.type)
-                else:
-                    self.typemap.setdefault(co.key, co.type)
+            # TODO:  hackish.  try to get a more polymorphic approach.
+            if hasattr(c, 'columns'):
+                for co in c.columns:
+                    co.accept_visitor(self)
+                    inner_columns.append(co)
+                    if select.use_labels:
+                        self.typemap.setdefault(co.label, co.type)
+                    else:
+                        self.typemap.setdefault(co.key, co.type)
+            else:
+                c.accept_visitor(self)
+                inner_columns.append(c)
                 
         if select.use_labels:
             collist = string.join(["%s AS %s" % (self.get_str(c), c.label) for c in inner_columns], ', ')
index d5dd69dda39902e9d6ba07235662471a652697aa..bebf38efed849540eba5399ff55360513d83e74a 100644 (file)
@@ -221,6 +221,14 @@ class Column(SchemaItem):
     def __ne__(self, other): return self._impl.__ne__(other)
     def __gt__(self, other): return self._impl.__gt__(other)
     def __ge__(self, other): return self._impl.__ge__(other)
+    def __add__(self, other): return self._impl.__add__(other)
+    def __sub__(self, other): return self._impl.__sub__(other)
+    def __mul__(self, other): return self._impl.__mul__(other)
+    def __and__(self, other): return self._impl.__and__(other)
+    def __or__(self, other): return self._impl.__or__(other)
+    def __div__(self, other): return self._impl.__div__(other)
+    def __truediv__(self, other): return self._impl.__truediv__(other)
+    def __invert__(self, other): return self._impl.__invert__(other)
     def __str__(self): return self._impl.__str__()
 
 class ForeignKey(SchemaItem):
index a030bf4396dfb293a199b35f5788335d2f1092da..4f7090cf7bc74c6f010ffc191f0499bf4d65ffa1 100644 (file)
@@ -23,7 +23,7 @@ import sqlalchemy.util as util
 import sqlalchemy.types as types
 import string, re
 
-__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence', 'exists']
+__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'sequence', 'exists']
 
 def desc(column):
     """returns a descending ORDER BY clause element, e.g.:
@@ -143,6 +143,9 @@ def alias(*args, **params):
 def subquery(alias, *args, **params):
     return Alias(Select(*args, **params), alias)
 
+def literal(value, type=None):
+    return BindParamClause('literal', value, type=type)
+    
 def bindparam(key, value = None, type=None):
     if isinstance(key, schema.Column):
         return BindParamClause(key.name, value, type=key.type)
@@ -323,6 +326,13 @@ class ClauseElement(object):
         sequences, rowcounts, etc."""
         return self.execute(*multiparams, **params).fetchone()[0]
 
+    def __and__(self, other):
+        return and_(self, other)
+    def __or__(self, other):
+        return or_(self, other)
+    def __invert__(self):
+        return not_(self)
+
 class CompareMixin(object):
     def __lt__(self, other):
         return self._compare('<', other)
@@ -364,6 +374,28 @@ class CompareMixin(object):
     def endswith(self, other):
         return self._compare('LIKE', "%" + str(other))
 
+    # and here come the math operators:
+    def __add__(self, other):
+        return self._compare('+', other)
+    def __sub__(self, other):
+        return self._compare('-', other)
+    def __mul__(self, other):
+        return self._compare('*', other)
+    def __div__(self, other):
+        return self._compare('/', other)
+    def __truediv__(self, other):
+        return self._compare('/', other)
+    def _compare(self, operator, obj):
+        if _is_literal(obj):
+            if obj is None:
+                if operator != '=':
+                    raise "Only '=' operator can be used with NULL"
+                return BinaryClause(self, null(), 'IS')
+            else:
+                obj = BindParamClause('literal', obj, shortname=None, type=self.type)
+
+        return BinaryClause(self, obj, operator)
+
         
 class ColumnClause(ClauseElement, CompareMixin):
     """represents a textual column clause in a SQL statement."""
@@ -427,7 +459,7 @@ class FromClause(ClauseElement):
     def accept_visitor(self, visitor): 
         visitor.visit_fromclause(self)
     
-class BindParamClause(ClauseElement):
+class BindParamClause(ClauseElement, CompareMixin):
     def __init__(self, key, value, shortname = None, type = None):
         self.key = key
         self.value = value
@@ -532,12 +564,19 @@ class CompoundClause(ClauseList):
 class Function(ClauseList, CompareMixin):
     """describes a SQL function. extends ClauseList to provide comparison operators."""
     def __init__(self, name, *clauses, **kwargs):
-        ClauseList.__init__(self, parens=True, *clauses)
         self.name = name
         self.type = kwargs.get('type', None)
         self.label = kwargs.get('label', None)
+        ClauseList.__init__(self, parens=True, *clauses)
     columns = property(lambda self: [self])
     key = property(lambda self:self.label or self.name)
+    def append(self, clause):
+        if _is_literal(clause):
+            if clause is None:
+                clause = null()
+            else:
+                clause = BindParamClause(self.name, clause, shortname=self.name, type=None)
+        self.clauses.append(clause)
     def copy_container(self):
         return self
     def accept_visitor(self, visitor):
@@ -941,13 +980,15 @@ class Select(Selectable, TailClauseMixin):
             if self.rowid_column is None and hasattr(f, 'rowid_column'):
                 self.rowid_column = f.rowid_column._make_proxy(self)
         column._process_from_dict(self._froms, False)
-        
-        for co in column.columns:
-            if self.use_labels:
-                co._make_proxy(self, name = co.label)
-            else:
-                co._make_proxy(self)
 
+        # TODO: dont use hasattr here, get a general way to locate
+        # selectable columns off stuff working completely (i.e. Selectable)
+        if hasattr(column, 'columns'):
+            for co in column.columns:
+                if self.use_labels:
+                    co._make_proxy(self, name = co.label)
+                else:
+                    co._make_proxy(self)
 
     def get_col_by_original(self, column):
         if self.use_labels:
index deaa02c6cdf8fbefd1dcadec7e013090122402de..54284f0c0de9ae8d7e1b3727646711c9cddbe00d 100644 (file)
@@ -43,7 +43,7 @@ class MapperTest(MapperSuperTest):
 
     def testmultitable(self):
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
-        m = mapper(User, usersaddresses, primarytable = users, primary_keys=[users.c.user_id])
+        m = mapper(User, usersaddresses, primarytable = users, primary_key=[users.c.user_id])
         l = m.select()
         self.assert_result(l, User, {'user_id' : 7}, {'user_id' : 8})
 
index f17b49c1238a312d215346dfe50387fdc371ea0a..0942e41eec8d83b6f3babd682638233e6f7b1728 100644 (file)
@@ -604,7 +604,7 @@ class SaveTest(AssertMixin):
         m = mapper(Item, items, properties = dict(
                 keywords = relation(IKAssociation, itemkeywords, lazy = False, properties = dict(
                     keyword = relation(Keyword, keywords, lazy = False, uselist = False)
-                ), primary_keys = [itemkeywords.c.item_id, itemkeywords.c.keyword_id])
+                ), primary_key = [itemkeywords.c.item_id, itemkeywords.c.keyword_id])
             ))
 
         data = [Item,
index b98ff393b2a87a8f072c6e4032314fb4400e7993..66a4b9a60bc3a2eccde2c8c9e3dd3642df537b84 100644 (file)
@@ -208,7 +208,7 @@ class KeywordAssociation(object):pass
 # lazy loading for that.
 m = mapper(Article, articles, properties=dict(
     keywords = relation(KeywordAssociation, itemkeywords, lazy = False, 
-        primary_keys=[itemkeywords.c.article_id, itemkeywords.c.keyword_id], 
+        primary_key=[itemkeywords.c.article_id, itemkeywords.c.keyword_id], 
         properties=dict(
             keyword = relation(Keyword, keywords, lazy = False),
             user = relation(User, users, lazy = True)
@@ -236,4 +236,3 @@ for a in alist:
         if k.keyword.name == 'jacks_stories':
             print k.user.user_name
 
-