From: Mike Bayer Date: Tue, 17 Jul 2007 01:14:33 +0000 (+0000) Subject: - moved query._with_parent into prop.compare() calls X-Git-Tag: rel_0_4_6~95 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9ff169028d74c65350d41030d89119ab06c8ac49;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - moved query._with_parent into prop.compare() calls - built extensible operator framework in sql package, ORM builds on top of it to shuttle python operator objects back down to the individual columns. no relation() comparisons yet. implements half of [ticket:643] --- diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 8329fbaec9..210e4f2c59 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1150,9 +1150,7 @@ class ResultProxy(object): elif isinstance(key, basestring) and key.lower() in props: rec = props[key.lower()] elif isinstance(key, sql.ColumnElement): - print "LABEL ON COLUMN", repr(key.key), "IS", repr(key._label) label = context.column_labels.get(key._label, key.name).lower() - print "SO YEAH, NOW WE GOT LABEL", repr(label), "AND PROPS IS", repr(props) if label in props: rec = props[label] diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 0351af214c..84d464e204 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import util +from sqlalchemy import util, sql from sqlalchemy.orm import util as orm_util, interfaces, collections from sqlalchemy.orm.mapper import class_mapper from sqlalchemy import logging, exceptions @@ -14,15 +14,55 @@ import weakref PASSIVE_NORESULT = object() ATTR_WAS_SET = object() -class InstrumentedAttribute(object): - def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, **kwargs): +class InstrumentedAttribute(sql.Comparator): + """attribute access for instrumented classes.""" + + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, comparator=None, **kwargs): + """Construct an InstrumentedAttribute. + + class_ + the class to be instrumented. + + manager + AttributeManager managing this class + + key + string name of the attribute + + callable_ + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present already. + + trackparent + if True, attempt to track if an instance has a parent attached to it + via this attribute + + extension + an AttributeExtension object which will receive + set/delete/append/remove/etc. events + + compare_function + a function that compares two values which are normally assignable to this + attribute + + mutable_scalars + if True, the values which are normally assignable to this attribute can mutate, + and need to be compared against a copy of their original contents in order to + detect changes on the parent instance + + comparator + a sql.Comparator to which compare/math events will be sent + + """ + self.class_ = class_ self.manager = manager self.key = key self.callable_ = callable_ self.trackparent = trackparent self.mutable_scalars = mutable_scalars - + self.comparator = comparator self.copy = None if compare_function is None: self.is_equal = lambda x,y: x == y @@ -41,6 +81,15 @@ class InstrumentedAttribute(object): return self return self.get(obj) + def compare_self(self): + return self.comparator.compare_self() + + def operate(self, op, other): + return self.comparator.operate(op, other) + + def reverse_operate(self, op, other): + return self.comparator.reverse_operate(op, other) + def hasparent(self, item, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. @@ -242,6 +291,8 @@ InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) class InstrumentedScalarAttribute(InstrumentedAttribute): + """represents a scalar-holding InstrumentedAttribute.""" + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): super(InstrumentedScalarAttribute, self).__init__(class_, manager, key, callable_, trackparent=trackparent, extension=extension, @@ -295,6 +346,9 @@ class InstrumentedScalarAttribute(InstrumentedAttribute): obj.__dict__[self.key] = value self.fire_replace_event(obj, value, old, initiator) + type = property(lambda self: self.property.columns[0].type) + + class InstrumentedCollectionAttribute(InstrumentedAttribute): """A collection-holding attribute that instruments changes in membership. @@ -592,17 +646,7 @@ class AttributeHistory(object): return self.attr.hasparent(obj) class AttributeManager(object): - """Allow the instrumentation of object attributes. - - ``AttributeManager`` is stateless, but can be overridden by - subclasses to redefine some of its factory operations. Also be - aware ``AttributeManager`` will cache attributes for a given - class, allowing not to determine those for each objects (used in - ``managed_attributes()`` and - ``noninherited_managed_attributes()``). This cache is cleared for - a given class while calling ``register_attribute()``, and can be - cleared using ``clear_attribute_cache()``. - """ + """Allow the instrumentation of object attributes.""" def __init__(self): # will cache attributes, indexed by class objects diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index f1bb20a818..b0a2399f3b 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -5,7 +5,7 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import util, logging +from sqlalchemy import util, logging, sql # returned by a MapperExtension method to indicate a "do nothing" response EXT_PASS = object() @@ -334,6 +334,11 @@ class MapperProperty(object): raise NotImplementedError() +class PropComparator(sql.Comparator): + """defines comparison operations for MapperProperty objects""" + + def __init__(self, prop): + self.prop = prop class StrategizedProperty(MapperProperty): """A MapperProperty which uses selectable strategies to affect diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index de844ee236..a32354fcad 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -33,6 +33,7 @@ class ColumnProperty(StrategizedProperty): self.columns = list(columns) self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) + self.comparator = ColumnProperty.ColumnComparator(self) def create_strategy(self): if self.deferred: @@ -55,11 +56,23 @@ class ColumnProperty(StrategizedProperty): def merge(self, session, source, dest, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) - def compare(self, value): - return self.columns[0] == value + def compare(self, value, op='=='): + return self.comparator == value def get_col_value(self, column, value): return value + + class ColumnComparator(PropComparator): + def compare_self(self): + return self.prop.columns[0] + + def operate(self, op, other): + return op(self.prop.columns[0], other) + + def reverse_operate(self, op, other): + col = self.prop.columns[0] + return op(col._bind_param(other), col) + ColumnProperty.logger = logging.class_logger(ColumnProperty) @@ -71,7 +84,8 @@ class CompositeProperty(ColumnProperty): def __init__(self, class_, *columns, **kwargs): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ - + self.comparator = None + def copy(self): return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) @@ -87,8 +101,12 @@ class CompositeProperty(ColumnProperty): if a is column: setattr(obj, b, value) - def compare(self, value): - return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())]) + def compare(self, value, op='=='): + # TODO: build into operator framework + if op == '==': + return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())]) + elif op == '!=': + return sql.or_([a!=b for a, b in zip(self.columns, value.__colset__())]) def get_col_value(self, column, value): for a, b in zip(self.columns, value.__colset__()): @@ -119,6 +137,7 @@ class PropertyLoader(StrategizedProperty): self.remote_side = util.to_set(remote_side) self.enable_typechecks = enable_typechecks self._parent_join_cache = {} + self.comparator = None if cascade is not None: self.cascade = mapperutil.CascadeOptions(cascade) @@ -143,8 +162,21 @@ class PropertyLoader(StrategizedProperty): self.backref = backref self.is_backref = is_backref - def compare(self, value): - return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))]) + def compare(self, value, value_is_parent=False, op='=='): + if op == '==': + # optimized operation for ==, uses a lazy clause. + (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent) + bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + + class Visitor(sql.ClauseVisitor): + def visit_bindparam(s, bindparam): + mapper = value_is_parent and self.parent or self.mapper + bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key]) + Visitor().traverse(criterion) + return criterion + else: + # TODO: build expressions like these into operator framework + return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))]) private = property(lambda s:s.cascade.delete_orphan) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 12070b2b42..3937149ee5 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -98,26 +98,6 @@ class Query(object): if instance is None: raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) return instance - - - def _with_lazy_criterion(cls, instance, prop, reverse=False): - """extract query criterion from a LazyLoader strategy given a Mapper, - source persisted/detached instance and PropertyLoader. - - """ - - from sqlalchemy.orm import strategies - (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(prop, reverse_direction=reverse) - bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - - class Visitor(sql.ClauseVisitor): - def visit_bindparam(self, bindparam): - mapper = reverse and prop.mapper or prop.parent - bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key]) - Visitor().traverse(criterion) - return criterion - _with_lazy_criterion = classmethod(_with_lazy_criterion) - def query_from_parent(cls, instance, property, **kwargs): """return a newly constructed Query object, with criterion corresponding to @@ -140,7 +120,7 @@ class Query(object): mapper = object_mapper(instance) prop = mapper.get_property(property, resolve_synonyms=True) target = prop.mapper - criterion = cls._with_lazy_criterion(instance, prop) + criterion = prop.compare(instance, value_is_parent=True) return Query(target, **kwargs).filter(criterion) query_from_parent = classmethod(query_from_parent) @@ -169,7 +149,7 @@ class Query(object): raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__)) else: prop = mapper.get_property(property, resolve_synonyms=True) - return self.filter(Query._with_lazy_criterion(instance, prop)) + return self.filter(prop.compare(instance, value_is_parent=True)) def add_entity(self, entity): """add a mapped entity to the list of result columns to be returned. @@ -285,10 +265,8 @@ class Query(object): for key, value in kwargs.iteritems(): prop = joinpoint.get_property(key, resolve_synonyms=True) - if isinstance(prop, properties.PropertyLoader): - c = self._with_lazy_criterion(value, prop, True) # & self.join_via(keys[:-1]) - use aliasized join feature - else: - c = prop.compare(value) # & self.join_via(keys) - use aliasized join feature + c = prop.compare(value) + if alias is not None: sql_util.ClauseAdapter(alias).traverse(c) if clause is None: @@ -1033,7 +1011,7 @@ class Query(object): for key, value in params.iteritems(): (keys, prop) = self._locate_prop(key, start=start) if isinstance(prop, properties.PropertyLoader): - c = self._with_lazy_criterion(value, prop, True) & self.join_via(keys[:-1]) + c = prop.compare(value) & self.join_via(keys[:-1]) else: c = prop.compare(value) & self.join_via(keys) if clause is None: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 0fccba0293..c790af71bb 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -43,12 +43,12 @@ class ColumnLoader(LoaderStrategy): return False else: return True - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True) + sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator) def _init_scalar_attribute(self): self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) coltype = self.columns[0].type - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable()) + sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) def create_row_processor(self, selectcontext, mapper, row): if self.is_composite: @@ -152,7 +152,7 @@ class DeferredColumnLoader(LoaderStrategy): def init_class_attribute(self): self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable()) + sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) def setup_query(self, context, **kwargs): if self.group is not None and context.attributes.get(('undefer', self.group), False): @@ -241,7 +241,7 @@ class AbstractRelationLoader(LoaderStrategy): def _register_attribute(self, class_, callable_=None): self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) - sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_) + sessionlib.attribute_manager.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator) class NoLoader(AbstractRelationLoader): def init_class_attribute(self): @@ -372,7 +372,7 @@ class LazyLoader(AbstractRelationLoader): sessionlib.attribute_manager.reset_instance_attribute(instance, self.key) return (execute, None) - def _create_lazy_clause(cls, prop, reverse_direction=False): + def _create_lazy_clause(cls, prop, reverse_direction=False, op='=='): (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side) binds = {} @@ -399,6 +399,11 @@ class LazyLoader(AbstractRelationLoader): rightcol = find_column_in_expr(binary.right) if leftcol is None or rightcol is None: return + + # TODO: comprehensive negation support for expressions + if op == '!=' and binary.operator == '==': + binary.operator = '!=' + if should_bind(leftcol, rightcol): col = leftcol binary.left = binds.setdefault(leftcol, diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index db7625382b..e177f4194a 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -26,7 +26,7 @@ are less guaranteed to stay the same in future releases. from sqlalchemy import util, exceptions, logging from sqlalchemy import types as sqltypes -import string, re, sets +import string, re, sets, operator __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', 'ClauseVisitor', 'ColumnCollection', 'ColumnElement', @@ -1126,44 +1126,144 @@ class ClauseElement(object): def _negate(self): return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) -class _CompareMixin(object): - """Defines comparison operations for ``ClauseElement`` instances. + +class Comparator(object): + """defines comparison and math operations""" + + def like_op(a, b): + return a.like(b) + like_op = staticmethod(like_op) - This is a mixin class that adds the capability to produce ``ClauseElement`` - instances based on regular Python operators. - These operations are achieved using Python's operator overload methods - (i.e. ``__eq__()``, ``__ne__()``, etc. + def between_op(a, b): + return a.between(b) + between_op = staticmethod(between_op) - Overridden operators include all comparison operators (i.e. '==', '!=', '<'), - math operators ('+', '-', '*', etc), the '&' and '|' operators which evaluate - to ``AND`` and ``OR`` respectively. - - Other methods exist to create additional SQL clauses such as ``IN``, ``LIKE``, - ``DISTINCT``, etc. + def in_op(a, b): + return a.in_(b) + in_op = staticmethod(in_op) - """ + def startswith_op(a, b): + return a.startswith(b) + startswith_op = staticmethod(startswith_op) + + def endswith_op(a, b): + return a.endswith(b) + endswith_op = staticmethod(endswith_op) + + def compare_self(self): + raise NotImplementedError() + + def operate(self, op, other): + raise NotImplementedError() + def reverse_operate(self, op, other): + raise NotImplementedError() + def __lt__(self, other): - return self._compare('<', other) + return self.operate(operator.lt, other) def __le__(self, other): - return self._compare('<=', other) + return self.operate(operator.le, other) def __eq__(self, other): - return self._compare('=', other) + return self.operate(operator.eq, other) def __ne__(self, other): - return self._compare('!=', other) + return self.operate(operator.ne, other) def __gt__(self, other): - return self._compare('>', other) + return self.operate(operator.gt, other) def __ge__(self, other): - return self._compare('>=', other) + return self.operate(operator.ge, other) def like(self, other): - """produce a ``LIKE`` clause.""" - return self._compare('LIKE', other) + return self.operate(Comparator.like_op, other) + + def in_(self, *other): + return self.operate(Comparator.in_op, other) + + def startswith(self, other): + return self.operate(Comparator.startswith_op, other) + + def endswith(self, other): + return self.operate(Comparator.endswith_op, other) + + def __radd__(self, other): + return self.reverse_operate(operator.add, other) + + def __rsub__(self, other): + return self.reverse_operate(operator.sub, other) + + def __rmul__(self, other): + return self.reverse_operate(operator.mul, other) + + def __rdiv__(self, other): + return self.reverse_operate(operator.div, other) + + def between(self, cleft, cright): + return self.operate(Comparator.between_op, (cleft, cright)) + + def __add__(self, other): + return self.operate(operator.add, other) + + def __sub__(self, other): + return self.operate(operator.sub, other) + + def __mul__(self, other): + return self.operate(operator.mul, other) + + def __div__(self, other): + return self.operate(operator.div, other) + + def __mod__(self, other): + return self.operate(operator.mod, other) + + def __truediv__(self, other): + return self.operate(operator.truediv, other) + +class _CompareMixin(Comparator): + """Defines comparison and math operations for ``ClauseElement`` instances.""" + + def __compare(self, operator, obj, negate=None): + if obj is None or isinstance(obj, _Null): + if operator == '=': + return _BinaryExpression(self.compare_self(), null(), 'IS', negate='IS NOT') + elif operator == '!=': + return _BinaryExpression(self.compare_self(), null(), 'IS NOT', negate='IS') + else: + raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + else: + obj = self._check_literal(obj) + + return _BinaryExpression(self.compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate) + + def __operate(self, operator, obj): + obj = self._check_literal(obj) + return _BinaryExpression(self.compare_self(), obj, operator, type=self._compare_type(obj)) + + operators = { + operator.add : (__operate, '+'), + operator.mul : (__operate, '*'), + operator.sub : (__operate, '-'), + operator.div : (__operate, '/'), + operator.mod : (__operate, '%'), + operator.truediv : (__operate, '/'), + operator.lt : (__compare, '<', '=>'), + operator.le : (__compare, '<=', '>'), + operator.ne : (__compare, '!=', '='), + operator.gt : (__compare, '>', '<='), + operator.ge : (__compare, '>=', '<'), + operator.eq : (__compare, '=', '!='), + Comparator.like_op : (__compare, 'LIKE', 'NOT LIKE'), + } + + def operate(self, op, other): + o = _CompareMixin.operators[op] + return o[0](self, o[1], other, *o[2:]) + + def reverse_operate(self, op, other): + return self._bind_param(other).operate(op, self) def in_(self, *other): """produce an ``IN`` clause.""" @@ -1175,7 +1275,7 @@ class _CompareMixin(object): return self.__eq__( o) #single item -> == else: assert hasattr( o, '_selectable') #better check? - return self._compare( 'IN', o, negate='NOT IN') #single selectable + return self.__compare( 'IN', o, negate='NOT IN') #single selectable args = [] for o in other: @@ -1185,12 +1285,12 @@ class _CompareMixin(object): else: o = self._bind_param(o) args.append(o) - return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') + return self.__compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') def startswith(self, other): """produce the clause ``LIKE '%'``""" perc = isinstance(other,(str,unicode)) and '%' or literal('%',type= sqltypes.String) - return self._compare('LIKE', other + perc) + return self.__compare('LIKE', other + perc) def endswith(self, other): """produce the clause ``LIKE '%'``""" @@ -1198,16 +1298,7 @@ class _CompareMixin(object): else: po = literal('%', type= sqltypes.String) + other po.type = sqltypes.to_instance( sqltypes.String) #force! - return self._compare('LIKE', po) - - def __radd__(self, other): - return self._bind_param(other)._operate('+', self) - def __rsub__(self, other): - return self._bind_param(other)._operate('-', self) - def __rmul__(self, other): - return self._bind_param(other)._operate('*', self) - def __rdiv__(self, other): - return self._bind_param(other)._operate('/', self) + return self.__compare('LIKE', po) def label(self, name): """produce a column label, i.e. `` AS ``""" @@ -1238,59 +1329,21 @@ class _CompareMixin(object): passed to the generated function. """ - return lambda other: self._operate(operator, other) - - # and here come the math operators: - - def __add__(self, other): - return self._operate('+', other) - - def __sub__(self, other): - return self._operate('-', other) - - def __mul__(self, other): - return self._operate('*', other) - - def __div__(self, other): - return self._operate('/', other) - - def __mod__(self, other): - return self._operate('%', other) - - def __truediv__(self, other): - return self._operate('/', other) + return lambda other: self.__operate(operator, other) def _bind_param(self, obj): return _BindParamClause('literal', obj, shortname=None, type=self.type, unique=True) def _check_literal(self, other): - if _is_literal(other): + if isinstance(other, Comparator): + return other.compare_self() + elif _is_literal(other): return self._bind_param(other) else: return other - - def _compare(self, operator, obj, negate=None): - if obj is None or isinstance(obj, _Null): - if operator == '=': - return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT') - elif operator == '!=': - return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS') - else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") - else: - obj = self._check_literal(obj) - - return _BinaryExpression(self._compare_self(), obj, operator, type=sqltypes.Boolean, negate=negate) - - def _operate(self, operator, obj): - if _is_literal(obj): - obj = self._bind_param(obj) - return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj)) - - def _compare_self(self): - """Allow ``ColumnImpl`` to return its ``Column`` object for - usage in ``ClauseElements``, all others to just return self. - """ + + def compare_self(self): + """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" return self @@ -2398,7 +2451,7 @@ class _Label(ColumnElement): _label = property(lambda s: s.name) orig_set = property(lambda s:s.obj.orig_set) - def _compare_self(self): + def compare_self(self): return self.obj def _copy_internals(self): diff --git a/test/orm/query.py b/test/orm/query.py index 57d533a91c..c5c48f4bb4 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -1,8 +1,10 @@ from sqlalchemy import * +from sqlalchemy import ansisql from sqlalchemy.orm import * import testbase from testbase import Table, Column from fixtures import * +import operator class Base(object): def __init__(self, **kwargs): @@ -97,10 +99,63 @@ class GetTest(QueryTest): mapper(LocalFoo, table) assert create_session().query(LocalFoo).get(ustring) == LocalFoo(id=ustring, data=ustring) +class OperatorTest(QueryTest): + """test sql.Comparator implementation for MapperProperties""" + + def _test(self, clause, expected): + c = str(clause.compile(dialect=ansisql.ANSIDialect())) + assert c == expected, "%s != %s" % (c, expected) + + def test_arithmetic(self): + create_session().query(User) + for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), + (operator.sub, '-'), (operator.div, '/'), + ): + for (lhs, rhs, res) in ( + ('a', User.id, ':users_id %s users.id'), + ('a', literal('b'), ':literal %s :literal_1'), + (User.id, 'b', 'users.id %s :users_id'), + (User.id, literal('b'), 'users.id %s :literal'), + (User.id, User.id, 'users.id %s users.id'), + (literal('a'), 'b', ':literal %s :literal_1'), + (literal('a'), User.id, ':literal %s users.id'), + (literal('a'), literal('b'), ':literal %s :literal_1'), + ): + self._test(py_op(lhs, rhs), res % sql_op) + + def test_comparison(self): + create_session().query(User) + for (py_op, fwd_op, rev_op) in ((operator.lt, '<', '>'), + (operator.gt, '>', '<'), + (operator.eq, '=', '='), + (operator.ne, '!=', '!='), + (operator.le, '<=', '>='), + (operator.ge, '>=', '<=')): + for (lhs, rhs, l_sql, r_sql) in ( + ('a', User.id, ':users_id', 'users.id'), + ('a', literal('b'), ':literal_1', ':literal'), # note swap! + (User.id, 'b', 'users.id', ':users_id'), + (User.id, literal('b'), 'users.id', ':literal'), + (User.id, User.id, 'users.id', 'users.id'), + (literal('a'), 'b', ':literal', ':literal_1'), + (literal('a'), User.id, ':literal', 'users.id'), + (literal('a'), literal('b'), ':literal', ':literal_1'), + ): + + # the compiled clause should match either (e.g.): + # 'a' < 'b' -or- 'b' > 'a'. + compiled = str(py_op(lhs, rhs).compile(dialect=ansisql.ANSIDialect())) + fwd_sql = "%s %s %s" % (l_sql, fwd_op, r_sql) + rev_sql = "%s %s %s" % (r_sql, rev_op, l_sql) + + self.assert_(compiled == fwd_sql or compiled == rev_sql, + "\n'" + compiled + "'\n does not match\n'" + + fwd_sql + "'\n or\n'" + rev_sql + "'") + class CompileTest(QueryTest): def test_deferred(self): session = create_session() - s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), addresses.c.user_id==users.c.id)).compile() + s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile() l = session.query(User).instances(s.execute(emailad = 'jack@bean.com')) assert [User(id=7)] == l @@ -109,7 +164,7 @@ class SliceTest(QueryTest): def test_first(self): assert User(id=7) == create_session().query(User).first() - assert create_session().query(User).filter(users.c.id==27).first() is None + assert create_session().query(User).filter(User.id==27).first() is None # more slice tests are available in test/orm/generative.py @@ -122,7 +177,7 @@ class TextTest(QueryTest): assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all() - assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all() + assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all() def test_binds(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() @@ -139,14 +194,8 @@ class FilterTest(QueryTest): assert User(id=8) == create_session().query(User)[1] def test_onefilter(self): - assert [User(id=8), User(id=9)] == create_session().query(User).filter(users.c.name.endswith('ed')).all() + assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all() - def test_typecheck(self): - try: - create_session().query(User).filter(User.name==5) - assert False - except exceptions.ArgumentError, e: - assert str(e) == "filter() argument must be of type sqlalchemy.sql.ClauseElement or string" class CountTest(QueryTest): def test_basic(self): @@ -163,7 +212,7 @@ class TextTest(QueryTest): assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter("id=9").all() - assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(users.c.id==9).all() + assert [User(id=9)] == create_session().query(User).filter("name='fred'").filter(User.id==9).all() def test_binds(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter("id in (:id1, :id2)").params(id1=8, id2=9).all() diff --git a/test/sql/select.py b/test/sql/select.py index 8c1b9da7d6..d5b00e1dab 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -267,19 +267,6 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A ) def testoperators(self): - self.runtest( - table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name" - ) - - self.runtest( - table1.select((table1.c.myid != 12) & ~table1.c.name), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name" - ) - - self.runtest( - literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2" - ) # exercise arithmetic operators for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'), @@ -325,6 +312,25 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A "\n'" + compiled + "'\n does not match\n'" + fwd_sql + "'\n or\n'" + rev_sql + "'") + self.runtest( + table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND mytable.name != :mytable_name" + ) + + self.runtest( + table1.select((table1.c.myid != 12) & ~and_(table1.c.name=='john', table1.c.name=='ed', table1.c.name=='fred')), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name AND mytable.name = :mytable_name_1 AND mytable.name = :mytable_name_2)" + ) + + self.runtest( + table1.select((table1.c.myid != 12) & ~table1.c.name), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name" + ) + + self.runtest( + literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2" + ) + # test the op() function, also that its results are further usable in expressions self.runtest( table1.select(table1.c.myid.op('hoho')(12)==14), @@ -978,8 +984,8 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE "SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))") self.runtest(table.select((5 + table.c.field).in_(5,6)), "SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)") - self.runtest(table.select(not_(table.c.field == 5)), - "SELECT op.field FROM op WHERE NOT op.field = :op_field") + self.runtest(table.select(not_(and_(table.c.field == 5, table.c.field == 7))), + "SELECT op.field FROM op WHERE NOT (op.field = :op_field AND op.field = :op_field_1)") self.runtest(table.select(not_(table.c.field) == 5), "SELECT op.field FROM op WHERE (NOT op.field) = :literal") self.runtest(table.select((table.c.field == table.c.field).between(False, True)),