]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved query._with_parent into prop.compare() calls
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Jul 2007 01:14:33 +0000 (01:14 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Jul 2007 01:14:33 +0000 (01:14 +0000)
- built extensible operator framework in sql package, ORM
builds on top of it to shuttle python operator objects back down
to the individual columns.  no relation() comparisons yet.  implements
half of [ticket:643]

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql.py
test/orm/query.py
test/sql/select.py

index 8329fbaec9c7a825a5ac3474e77a69973f56b2d9..210e4f2c59e5c976a67c6af5d2bd0ee016f9f0c0 100644 (file)
@@ -1150,9 +1150,7 @@ class ResultProxy(object):
             elif isinstance(key, basestring) and key.lower() in props:
                 rec = props[key.lower()]
             elif isinstance(key, sql.ColumnElement):
-                print "LABEL ON COLUMN", repr(key.key), "IS", repr(key._label)
                 label = context.column_labels.get(key._label, key.name).lower()
-                print "SO YEAH, NOW WE GOT LABEL", repr(label), "AND PROPS IS", repr(props)
                 if label in props:
                     rec = props[label]
 
index 0351af214c0e70cf3fe1c87beb2d5db73b77d911..84d464e2047e66ad001212f7e6e883bdfa0b2ccf 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import util
+from sqlalchemy import util, sql
 from sqlalchemy.orm import util as orm_util, interfaces, collections
 from sqlalchemy.orm.mapper import class_mapper
 from sqlalchemy import logging, exceptions
@@ -14,15 +14,55 @@ import weakref
 PASSIVE_NORESULT = object()
 ATTR_WAS_SET = object()
 
-class InstrumentedAttribute(object):
-    def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, **kwargs):
+class InstrumentedAttribute(sql.Comparator):
+    """attribute access for instrumented classes."""
+    
+    def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs):
+        """Construct an InstrumentedAttribute.
+        
+            class_
+              the class to be instrumented.
+                
+            manager
+              AttributeManager managing this class
+              
+            key
+              string name of the attribute
+              
+            callable_
+              optional function which generates a callable based on a parent 
+              instance, which produces the "default" values for a scalar or 
+              collection attribute when it's first accessed, if not present already.
+              
+            trackparent
+              if True, attempt to track if an instance has a parent attached to it 
+              via this attribute
+              
+            extension
+              an AttributeExtension object which will receive 
+              set/delete/append/remove/etc. events 
+              
+            compare_function
+              a function that compares two values which are normally assignable to this 
+              attribute
+              
+            mutable_scalars
+              if True, the values which are normally assignable to this attribute can mutate, 
+              and need to be compared against a copy of their original contents in order to 
+              detect changes on the parent instance
+              
+            comparator
+              a sql.Comparator to which compare/math events will be sent
+              
+        """
+        
         self.class_ = class_
         self.manager = manager
         self.key = key
         self.callable_ = callable_
         self.trackparent = trackparent
         self.mutable_scalars = mutable_scalars
-
+        self.comparator = comparator
         self.copy = None
         if compare_function is None:
             self.is_equal = lambda x,y: x == y
@@ -41,6 +81,15 @@ class InstrumentedAttribute(object):
             return self
         return self.get(obj)
 
+    def compare_self(self):
+        return self.comparator.compare_self()
+        
+    def operate(self, op, other):
+        return self.comparator.operate(op, other)
+
+    def reverse_operate(self, op, other):
+        return self.comparator.reverse_operate(op, other)
+        
     def hasparent(self, item, optimistic=False):
         """Return the boolean value of a `hasparent` flag attached to the given item.
 
@@ -242,6 +291,8 @@ InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute)
 
         
 class InstrumentedScalarAttribute(InstrumentedAttribute):
+    """represents a scalar-holding InstrumentedAttribute."""
+    
     def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs):
         super(InstrumentedScalarAttribute, self).__init__(class_, manager, key,
           callable_, trackparent=trackparent, extension=extension,
@@ -295,6 +346,9 @@ class InstrumentedScalarAttribute(InstrumentedAttribute):
         obj.__dict__[self.key] = value
         self.fire_replace_event(obj, value, old, initiator)
 
+    type = property(lambda self: self.property.columns[0].type)
+
+        
 class InstrumentedCollectionAttribute(InstrumentedAttribute):
     """A collection-holding attribute that instruments changes in membership.
 
@@ -592,17 +646,7 @@ class AttributeHistory(object):
         return self.attr.hasparent(obj)
 
 class AttributeManager(object):
-    """Allow the instrumentation of object attributes.
-
-    ``AttributeManager`` is stateless, but can be overridden by
-    subclasses to redefine some of its factory operations. Also be
-    aware ``AttributeManager`` will cache attributes for a given
-    class, allowing not to determine those for each objects (used in
-    ``managed_attributes()`` and
-    ``noninherited_managed_attributes()``). This cache is cleared for
-    a given class while calling ``register_attribute()``, and can be
-    cleared using ``clear_attribute_cache()``.
-    """
+    """Allow the instrumentation of object attributes."""
 
     def __init__(self):
         # will cache attributes, indexed by class objects
index f1bb20a81877856f05ea27be31f8ca55a1746ba6..b0a2399f3b48adda49712154e26ec421a1cac5f4 100644 (file)
@@ -5,7 +5,7 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 
-from sqlalchemy import util, logging
+from sqlalchemy import util, logging, sql
 
 # returned by a MapperExtension method to indicate a "do nothing" response
 EXT_PASS = object()
@@ -334,6 +334,11 @@ class MapperProperty(object):
         raise NotImplementedError()
 
 
+class PropComparator(sql.Comparator):
+    """defines comparison operations for MapperProperty objects"""
+    
+    def __init__(self, prop):
+        self.prop = prop
 
 class StrategizedProperty(MapperProperty):
     """A MapperProperty which uses selectable strategies to affect
index de844ee236e0e5cb0b6bfa6db719a260233484eb..a32354fcad96926b61d53b444766615b265d88ee 100644 (file)
@@ -33,6 +33,7 @@ class ColumnProperty(StrategizedProperty):
         self.columns = list(columns)
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
+        self.comparator = ColumnProperty.ColumnComparator(self)
         
     def create_strategy(self):
         if self.deferred:
@@ -55,11 +56,23 @@ class ColumnProperty(StrategizedProperty):
     def merge(self, session, source, dest, _recursive):
         setattr(dest, self.key, getattr(source, self.key, None))
 
-    def compare(self, value):
-        return self.columns[0] == value
+    def compare(self, value, op='=='):
+        return self.comparator == value
 
     def get_col_value(self, column, value):
         return value
+
+    class ColumnComparator(PropComparator):
+        def compare_self(self):
+            return self.prop.columns[0]
+            
+        def operate(self, op, other):
+            return op(self.prop.columns[0], other)
+
+        def reverse_operate(self, op, other):
+            col = self.prop.columns[0]
+            return op(col._bind_param(other), col)
+            
             
 ColumnProperty.logger = logging.class_logger(ColumnProperty)
 
@@ -71,7 +84,8 @@ class CompositeProperty(ColumnProperty):
     def __init__(self, class_, *columns, **kwargs):
         super(CompositeProperty, self).__init__(*columns, **kwargs)
         self.composite_class = class_
-
+        self.comparator = None
+        
     def copy(self):
         return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
 
@@ -87,8 +101,12 @@ class CompositeProperty(ColumnProperty):
             if a is column:
                 setattr(obj, b, value)
 
-    def compare(self, value):
-        return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())])
+    def compare(self, value, op='=='):
+        # TODO: build into operator framework
+        if op == '==':
+            return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())])
+        elif op == '!=':
+            return sql.or_([a!=b for a, b in zip(self.columns, value.__colset__())])
 
     def get_col_value(self, column, value):
         for a, b in zip(self.columns, value.__colset__()):
@@ -119,6 +137,7 @@ class PropertyLoader(StrategizedProperty):
         self.remote_side = util.to_set(remote_side)
         self.enable_typechecks = enable_typechecks
         self._parent_join_cache = {}
+        self.comparator = None
 
         if cascade is not None:
             self.cascade = mapperutil.CascadeOptions(cascade)
@@ -143,8 +162,21 @@ class PropertyLoader(StrategizedProperty):
             self.backref = backref
         self.is_backref = is_backref
 
-    def compare(self, value):
-        return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
+    def compare(self, value, value_is_parent=False, op='=='):
+        if op == '==':
+            # optimized operation for ==, uses a lazy clause.
+            (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+            bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+            class Visitor(sql.ClauseVisitor):
+                def visit_bindparam(s, bindparam):
+                    mapper = value_is_parent and self.parent or self.mapper
+                    bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+            Visitor().traverse(criterion)
+            return criterion
+        else:
+            # TODO: build expressions like these into operator framework
+            return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
 
     private = property(lambda s:s.cascade.delete_orphan)
 
index 12070b2b421b3057e91865d8d2814d7a0b096ed1..3937149ee594a26eac97b65b15f054e7731cef37 100644 (file)
@@ -98,26 +98,6 @@ class Query(object):
         if instance is None:
             raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
         return instance
-
-
-    def _with_lazy_criterion(cls, instance, prop, reverse=False):
-        """extract query criterion from a LazyLoader strategy given a Mapper, 
-        source persisted/detached instance and PropertyLoader.
-        
-        """
-        
-        from sqlalchemy.orm import strategies
-        (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(prop, reverse_direction=reverse)
-        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
-        class Visitor(sql.ClauseVisitor):
-            def visit_bindparam(self, bindparam):
-                mapper = reverse and prop.mapper or prop.parent
-                bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
-        Visitor().traverse(criterion)
-        return criterion
-    _with_lazy_criterion = classmethod(_with_lazy_criterion)
-    
         
     def query_from_parent(cls, instance, property, **kwargs):
         """return a newly constructed Query object, with criterion corresponding to 
@@ -140,7 +120,7 @@ class Query(object):
         mapper = object_mapper(instance)
         prop = mapper.get_property(property, resolve_synonyms=True)
         target = prop.mapper
-        criterion = cls._with_lazy_criterion(instance, prop)
+        criterion = prop.compare(instance, value_is_parent=True)
         return Query(target, **kwargs).filter(criterion)
     query_from_parent = classmethod(query_from_parent)
         
@@ -169,7 +149,7 @@ class Query(object):
                 raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
         else:
             prop = mapper.get_property(property, resolve_synonyms=True)
-        return self.filter(Query._with_lazy_criterion(instance, prop))
+        return self.filter(prop.compare(instance, value_is_parent=True))
 
     def add_entity(self, entity):
         """add a mapped entity to the list of result columns to be returned.
@@ -285,10 +265,8 @@ class Query(object):
 
         for key, value in kwargs.iteritems():
             prop = joinpoint.get_property(key, resolve_synonyms=True)
-            if isinstance(prop, properties.PropertyLoader):
-                c = self._with_lazy_criterion(value, prop, True) # & self.join_via(keys[:-1]) - use aliasized join feature
-            else:
-                c = prop.compare(value) # & self.join_via(keys) - use aliasized join feature
+            c = prop.compare(value)
+
             if alias is not None:
                 sql_util.ClauseAdapter(alias).traverse(c)
             if clause is None:
@@ -1033,7 +1011,7 @@ class Query(object):
         for key, value in params.iteritems():
             (keys, prop) = self._locate_prop(key, start=start)
             if isinstance(prop, properties.PropertyLoader):
-                c = self._with_lazy_criterion(value, prop, True) & self.join_via(keys[:-1])
+                c = prop.compare(value) & self.join_via(keys[:-1])
             else:
                 c = prop.compare(value) & self.join_via(keys)
             if clause is None:
index 0fccba0293f61923bb7b59c539b33e6a74ba9f10..c790af71bb9020cda71c131d402ccc11c2f5636a 100644 (file)
@@ -43,12 +43,12 @@ class ColumnLoader(LoaderStrategy):
                     return False
             else:
                 return True
-        sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True)
+        sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator)
 
     def _init_scalar_attribute(self):
         self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
         coltype = self.columns[0].type
-        sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable())
+        sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
         
     def create_row_processor(self, selectcontext, mapper, row):
         if self.is_composite:
@@ -152,7 +152,7 @@ class DeferredColumnLoader(LoaderStrategy):
 
     def init_class_attribute(self):
         self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
-        sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable())
+        sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
 
     def setup_query(self, context, **kwargs):
         if self.group is not None and context.attributes.get(('undefer', self.group), False):
@@ -241,7 +241,7 @@ class AbstractRelationLoader(LoaderStrategy):
         
     def _register_attribute(self, class_, callable_=None):
         self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
-        sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade,  trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_)
+        sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade,  trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator)
 
 class NoLoader(AbstractRelationLoader):
     def init_class_attribute(self):
@@ -372,7 +372,7 @@ class LazyLoader(AbstractRelationLoader):
                     sessionlib.attribute_manager.reset_instance_attribute(instance, self.key)
             return (execute, None)
 
-    def _create_lazy_clause(cls, prop, reverse_direction=False):
+    def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='):
         (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
         
         binds = {}
@@ -399,6 +399,11 @@ class LazyLoader(AbstractRelationLoader):
             rightcol = find_column_in_expr(binary.right)
             if leftcol is None or rightcol is None:
                 return
+            
+            # TODO: comprehensive negation support for expressions    
+            if op == '!=' and binary.operator == '==':
+                binary.operator = '!='
+                
             if should_bind(leftcol, rightcol):
                 col = leftcol
                 binary.left = binds.setdefault(leftcol,
index db7625382b79ac788d8f05a80cd6b26ef45bd3f3..e177f4194a2f8f11c605c6189ce933119f8b96ae 100644 (file)
@@ -26,7 +26,7 @@ are less guaranteed to stay the same in future releases.
 
 from sqlalchemy import util, exceptions, logging
 from sqlalchemy import types as sqltypes
-import string, re, sets
+import string, re, sets, operator
 
 __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
            'ClauseVisitor', 'ColumnCollection', 'ColumnElement',
@@ -1126,44 +1126,144 @@ class ClauseElement(object):
     def _negate(self):
         return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
 
-class _CompareMixin(object):
-    """Defines comparison operations for ``ClauseElement`` instances.
+
+class Comparator(object):
+    """defines comparison and math operations"""
+
+    def like_op(a, b):
+        return a.like(b)
+    like_op = staticmethod(like_op)
     
-    This is a mixin class that adds the capability to produce ``ClauseElement``
-    instances based on regular Python operators.  
-    These operations are achieved using Python's operator overload methods
-    (i.e. ``__eq__()``, ``__ne__()``, etc.
+    def between_op(a, b):
+        return a.between(b)
+    between_op = staticmethod(between_op)
     
-    Overridden operators include all comparison operators (i.e. '==', '!=', '<'),
-    math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate
-    to ``AND`` and ``OR`` respectively. 
-
-    Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``, 
-    ``DISTINCT``, etc.
+    def in_op(a, b):
+        return a.in_(b)
+    in_op = staticmethod(in_op)
     
-    """
+    def startswith_op(a, b):
+        return a.startswith(b)
+    startswith_op = staticmethod(startswith_op)
+    
+    def endswith_op(a, b):
+        return a.endswith(b)
+    endswith_op = staticmethod(endswith_op)
+    
+    def compare_self(self):
+        raise NotImplementedError()
+        
+    def operate(self, op, other):
+        raise NotImplementedError()
 
+    def reverse_operate(self, op, other):
+        raise NotImplementedError()
+    
     def __lt__(self, other):
-        return self._compare('<', other)
+        return self.operate(operator.lt, other)
 
     def __le__(self, other):
-        return self._compare('<=', other)
+        return self.operate(operator.le, other)
 
     def __eq__(self, other):
-        return self._compare('=', other)
+        return self.operate(operator.eq, other)
 
     def __ne__(self, other):
-        return self._compare('!=', other)
+        return self.operate(operator.ne, other)
 
     def __gt__(self, other):
-        return self._compare('>', other)
+        return self.operate(operator.gt, other)
 
     def __ge__(self, other):
-        return self._compare('>=', other)
+        return self.operate(operator.ge, other)
 
     def like(self, other):
-        """produce a ``LIKE`` clause."""
-        return self._compare('LIKE', other)
+        return self.operate(Comparator.like_op, other)
+
+    def in_(self, *other):
+        return self.operate(Comparator.in_op, other)
+
+    def startswith(self, other):
+        return self.operate(Comparator.startswith_op, other)
+
+    def endswith(self, other):
+        return self.operate(Comparator.endswith_op, other)
+
+    def __radd__(self, other):
+        return self.reverse_operate(operator.add, other)
+
+    def __rsub__(self, other):
+        return self.reverse_operate(operator.sub, other)
+
+    def __rmul__(self, other):
+        return self.reverse_operate(operator.mul, other)
+
+    def __rdiv__(self, other):
+        return self.reverse_operate(operator.div, other)
+
+    def between(self, cleft, cright):
+        return self.operate(Comparator.between_op, (cleft, cright))
+
+    def __add__(self, other):
+        return self.operate(operator.add, other)
+
+    def __sub__(self, other):
+        return self.operate(operator.sub, other)
+
+    def __mul__(self, other):
+        return self.operate(operator.mul, other)
+
+    def __div__(self, other):
+        return self.operate(operator.div, other)
+
+    def __mod__(self, other):
+        return self.operate(operator.mod, other)
+
+    def __truediv__(self, other):
+        return self.operate(operator.truediv, other)
+
+class _CompareMixin(Comparator):
+    """Defines comparison and math operations for ``ClauseElement`` instances."""
+
+    def __compare(self, operator, obj, negate=None):
+        if obj is None or isinstance(obj, _Null):
+            if operator == '=':
+                return _BinaryExpression(self.compare_self(), null(), 'IS', negate='IS NOT')
+            elif operator == '!=':
+                return _BinaryExpression(self.compare_self(), null(), 'IS NOT', negate='IS')
+            else:
+                raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+        else:
+            obj = self._check_literal(obj)
+
+        return _BinaryExpression(self.compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate)
+
+    def __operate(self, operator, obj):
+        obj = self._check_literal(obj)
+        return _BinaryExpression(self.compare_self(), obj, operator, type=self._compare_type(obj))
+
+    operators = {
+        operator.add : (__operate, '+'),
+        operator.mul : (__operate, '*'),
+        operator.sub : (__operate, '-'),
+        operator.div : (__operate, '/'),
+        operator.mod : (__operate, '%'),
+        operator.truediv : (__operate, '/'),
+        operator.lt : (__compare, '<', '=>'),
+        operator.le : (__compare, '<=', '>'),
+        operator.ne : (__compare, '!=', '='),
+        operator.gt : (__compare, '>', '<='),
+        operator.ge : (__compare, '>=', '<'),
+        operator.eq : (__compare, '=', '!='),
+        Comparator.like_op : (__compare, 'LIKE', 'NOT LIKE'),
+    }
+
+    def operate(self, op, other):
+        o = _CompareMixin.operators[op]
+        return o[0](self, o[1], other, *o[2:])
+    
+    def reverse_operate(self, op, other):
+        return self._bind_param(other).operate(op, self)
 
     def in_(self, *other):
         """produce an ``IN`` clause."""
@@ -1175,7 +1275,7 @@ class _CompareMixin(object):
                 return self.__eq__( o)    #single item -> ==
             else:
                 assert hasattr( o, '_selectable')   #better check?
-                return self._compare( 'IN', o, negate='NOT IN')   #single selectable
+                return self.__compare( 'IN', o, negate='NOT IN')   #single selectable
 
         args = []
         for o in other:
@@ -1185,12 +1285,12 @@ class _CompareMixin(object):
             else:
                 o = self._bind_param(o)
             args.append(o)
-        return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
+        return self.__compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
 
     def startswith(self, other):
         """produce the clause ``LIKE '<other>%'``"""
         perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String)
-        return self._compare('LIKE', other + perc)
+        return self.__compare('LIKE', other + perc)
 
     def endswith(self, other):
         """produce the clause ``LIKE '%<other>'``"""
@@ -1198,16 +1298,7 @@ class _CompareMixin(object):
         else:
             po = literal('%', type= sqltypes.String) + other
             po.type = sqltypes.to_instance( sqltypes.String)     #force!
-        return self._compare('LIKE', po)
-
-    def __radd__(self, other):
-        return self._bind_param(other)._operate('+', self)
-    def __rsub__(self, other):
-        return self._bind_param(other)._operate('-', self)
-    def __rmul__(self, other):
-        return self._bind_param(other)._operate('*', self)
-    def __rdiv__(self, other):
-        return self._bind_param(other)._operate('/', self)
+        return self.__compare('LIKE', po)
 
     def label(self, name):
         """produce a column label, i.e. ``<columnname> AS <name>``"""
@@ -1238,59 +1329,21 @@ class _CompareMixin(object):
             passed to the generated function.
             
         """
-        return lambda other: self._operate(operator, other)
-
-    # and here come the math operators:
-
-    def __add__(self, other):
-        return self._operate('+', other)
-
-    def __sub__(self, other):
-        return self._operate('-', other)
-
-    def __mul__(self, other):
-        return self._operate('*', other)
-
-    def __div__(self, other):
-        return self._operate('/', other)
-
-    def __mod__(self, other):
-        return self._operate('%', other)
-
-    def __truediv__(self, other):
-        return self._operate('/', other)
+        return lambda other: self.__operate(operator, other)
 
     def _bind_param(self, obj):
         return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True)
 
     def _check_literal(self, other):
-        if _is_literal(other):
+        if isinstance(other, Comparator):
+            return other.compare_self()
+        elif _is_literal(other):
             return self._bind_param(other)
         else:
             return other
-
-    def _compare(self, operator, obj, negate=None):
-        if obj is None or isinstance(obj, _Null):
-            if operator == '=':
-                return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
-            elif operator == '!=':
-                return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
-            else:
-                raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
-        else:
-            obj = self._check_literal(obj)
-
-        return _BinaryExpression(self._compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate)
-
-    def _operate(self, operator, obj):
-        if _is_literal(obj):
-            obj = self._bind_param(obj)
-        return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj))
-
-    def _compare_self(self):
-        """Allow ``ColumnImpl`` to return its ``Column`` object for
-        usage in ``ClauseElements``, all others to just return self.
-        """
+    
+    def compare_self(self):
+        """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
 
         return self
 
@@ -2398,7 +2451,7 @@ class _Label(ColumnElement):
     _label = property(lambda s: s.name)
     orig_set = property(lambda s:s.obj.orig_set)
 
-    def _compare_self(self):
+    def compare_self(self):
         return self.obj
     
     def _copy_internals(self):
index 57d533a91c491077d122638561f7e3e9a7f00cbb..c5c48f4bb463becbb14c059baf609d84a6aa0234 100644 (file)
@@ -1,8 +1,10 @@
 from sqlalchemy import *
+from sqlalchemy import ansisql
 from sqlalchemy.orm import *
 import testbase
 from testbase import Table, Column
 from fixtures import *
+import operator
 
 class Base(object):
     def __init__(self, **kwargs):
@@ -97,10 +99,63 @@ class GetTest(QueryTest):
         mapper(LocalFoo, table)
         assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring)
 
+class OperatorTest(QueryTest):
+    """test sql.Comparator implementation for MapperProperties"""
+    
+    def _test(self, clause, expected):
+        c = str(clause.compile(dialect=ansisql.ANSIDialect()))
+        assert c == expected, "%s != %s" % (c, expected)
+        
+    def test_arithmetic(self):
+        create_session().query(User)
+        for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
+                                (operator.sub, '-'), (operator.div, '/'),
+                                ):
+            for (lhs, rhs, res) in (
+                ('a', User.id, ':users_id %s users.id'),
+                ('a', literal('b'), ':literal %s :literal_1'),
+                (User.id, 'b', 'users.id %s :users_id'),
+                (User.id, literal('b'), 'users.id %s :literal'),
+                (User.id, User.id, 'users.id %s users.id'),
+                (literal('a'), 'b', ':literal %s :literal_1'),
+                (literal('a'), User.id, ':literal %s users.id'),
+                (literal('a'), literal('b'), ':literal %s :literal_1'),
+                ):
+                self._test(py_op(lhs, rhs), res % sql_op)
+
+    def test_comparison(self):
+        create_session().query(User)
+        for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'),
+                                        (operator.gt, '>', '<'),
+                                        (operator.eq, '=', '='),
+                                        (operator.ne, '!=', '!='),
+                                        (operator.le, '<=', '>='),
+                                        (operator.ge, '>=', '<=')):
+            for (lhs, rhs, l_sql, r_sql) in (
+                ('a', User.id, ':users_id', 'users.id'),
+                ('a', literal('b'), ':literal_1', ':literal'), # note swap!
+                (User.id, 'b', 'users.id', ':users_id'),
+                (User.id, literal('b'), 'users.id', ':literal'),
+                (User.id, User.id, 'users.id', 'users.id'),
+                (literal('a'), 'b', ':literal', ':literal_1'),
+                (literal('a'), User.id, ':literal', 'users.id'),
+                (literal('a'), literal('b'), ':literal', ':literal_1'),
+                ):
+
+                # the compiled clause should match either (e.g.):
+                # 'a' < 'b' -or- 'b' > 'a'.
+                compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect()))
+                fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql)
+                rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql)
+
+                self.assert_(compiled == fwd_sql or compiled == rev_sql,
+                             "\n'" + compiled + "'\n does not match\n'" +
+                             fwd_sql + "'\n or\n'" + rev_sql + "'")
+    
 class CompileTest(QueryTest):
     def test_deferred(self):
         session = create_session()
-        s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.id)).compile()
+        s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile()
         
         l = session.query(User).instances(s.execute(emailad = 'jack@bean.com'))
         assert [User(id=7)] == l
@@ -109,7 +164,7 @@ class SliceTest(QueryTest):
     def test_first(self):
         assert  User(id=7) == create_session().query(User).first()
         
-        assert create_session().query(User).filter(users.c.id==27).first() is None
+        assert create_session().query(User).filter(User.id==27).first() is None
         
         # more slice tests are available in test/orm/generative.py
         
@@ -122,7 +177,7 @@ class TextTest(QueryTest):
 
         assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all()
 
-        assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all()
+        assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all()
 
     def test_binds(self):
         assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
@@ -139,14 +194,8 @@ class FilterTest(QueryTest):
         assert User(id=8) == create_session().query(User)[1]
         
     def test_onefilter(self):
-        assert [User(id=8), User(id=9)] == create_session().query(User).filter(users.c.name.endswith('ed')).all()
+        assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all()
 
-    def test_typecheck(self):
-        try:
-            create_session().query(User).filter(User.name==5)
-            assert False
-        except exceptions.ArgumentError, e:
-            assert str(e) == "filter() argument must be of type sqlalchemy.sql.ClauseElement or string"
 
 class CountTest(QueryTest):
     def test_basic(self):
@@ -163,7 +212,7 @@ class TextTest(QueryTest):
 
         assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all()
 
-        assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all()
+        assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all()
 
     def test_binds(self):
         assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all()
index 8c1b9da7d659b7eb6e5c090bfd97b4c3c4078b1d..d5b00e1dab9c28d7d5310b38ae0287d11b71a442 100644 (file)
@@ -267,19 +267,6 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         )
         
     def testoperators(self):
-        self.runtest(
-            table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), 
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name"
-        )
-
-        self.runtest(
-            table1.select((table1.c.myid != 12) & ~table1.c.name), 
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name"
-        )
-        
-        self.runtest(
-            literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
-        )
 
         # exercise arithmetic operators
         for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
@@ -325,6 +312,25 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
                              "\n'" + compiled + "'\n does not match\n'" +
                              fwd_sql + "'\n or\n'" + rev_sql + "'")
 
+        self.runtest(
+         table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), 
+         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND mytable.name != :mytable_name"
+        )
+
+        self.runtest(
+         table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')), 
+         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name AND mytable.name = :mytable_name_1 AND mytable.name = :mytable_name_2)"
+        )
+
+        self.runtest(
+         table1.select((table1.c.myid != 12) & ~table1.c.name), 
+         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name"
+        )
+
+        self.runtest(
+         literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
+        )
+
         # test the op() function, also that its results are further usable in expressions
         self.runtest(
             table1.select(table1.c.myid.op('hoho')(12)==14),
@@ -978,8 +984,8 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
             "SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))")
         self.runtest(table.select((5 + table.c.field).in_(5,6)),
             "SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)")
-        self.runtest(table.select(not_(table.c.field == 5)),
-            "SELECT op.field FROM op WHERE NOT op.field = :op_field")
+        self.runtest(table.select(not_(and_(table.c.field == 5, table.c.field == 7))),
+            "SELECT op.field FROM op WHERE NOT (op.field = :op_field AND op.field = :op_field_1)")
         self.runtest(table.select(not_(table.c.field) == 5),
             "SELECT op.field FROM op WHERE (NOT op.field) = :literal")
         self.runtest(table.select((table.c.field == table.c.field).between(False, True)),