From: Mike Bayer Date: Mon, 13 Aug 2012 18:37:58 +0000 (-0400) Subject: - develop new system of applying custom operators to ColumnElement classes. resembles X-Git-Tag: rel_0_8_0b1~255 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=d9b5991f9c21836e1d48555b949a402fc4ce6b35;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - develop new system of applying custom operators to ColumnElement classes. resembles that of the ORM so far. --- diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d0732b9135..f2014e964a 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -237,6 +237,9 @@ class PropComparator(operators.ColumnOperators): return self.__class__(self.prop, self.mapper, adapter) + def __getattr__(self, key): + return getattr(self.__clause_element__(), key) + @staticmethod def any_op(a, b, **kwargs): return a.any(b, **kwargs) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index e79e12f7b8..683a93d98b 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -765,6 +765,12 @@ class Column(SchemaItem, expression.ColumnClause): setups, such as the one demonstrated in the ORM documentation at :ref:`post_update`. + :param comparator_factory: a :class:`.operators.ColumnOperators` subclass + which will produce custom operator behavior. + + .. versionadded: 0.8 support for pluggable operators in + core column expressions. + :param default: A scalar, Python callable, or :class:`~sqlalchemy.sql.expression.ClauseElement` representing the *default value* for this column, which will be invoked upon insert @@ -885,7 +891,9 @@ class Column(SchemaItem, expression.ColumnClause): no_type = type_ is None - super(Column, self).__init__(name, None, type_) + super(Column, self).__init__(name, None, type_, + comparator_factory= + kwargs.pop('comparator_factory', None)) self.key = kwargs.pop('key', name) self.primary_key = kwargs.pop('primary_key', False) self.nullable = kwargs.pop('nullable', not self.primary_key) @@ -1074,6 +1082,7 @@ class Column(SchemaItem, expression.ColumnClause): name=self.name, type_=self.type, key = self.key, + comparator_factory = self.comparator_factory, primary_key = self.primary_key, nullable = self.nullable, unique = self.unique, @@ -1112,6 +1121,7 @@ class Column(SchemaItem, expression.ColumnClause): key = key if key else name if name else self.key, primary_key = self.primary_key, nullable = self.nullable, + comparator_factory = self.comparator_factory, quote=self.quote, _proxies=[self], *fk) except TypeError, e: # Py3K diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 2c1c047b57..84d7c1a299 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1773,32 +1773,6 @@ class ClauseElement(Visitable): return self - @util.deprecated('0.7', - 'Only SQL expressions which subclass ' - ':class:`.Executable` may provide the ' - ':func:`.execute` method.') - def execute(self, *multiparams, **params): - """Compile and execute this :class:`.ClauseElement`. - - """ - e = self.bind - if e is None: - label = getattr(self, 'description', self.__class__.__name__) - msg = ('This %s does not support direct execution.' % label) - raise exc.UnboundExecutionError(msg) - return e._execute_clauseelement(self, multiparams, params) - - @util.deprecated('0.7', - 'Only SQL expressions which subclass ' - ':class:`.Executable` may provide the ' - ':func:`.scalar` method.') - def scalar(self, *multiparams, **params): - """Compile and execute this :class:`.ClauseElement`, returning - the result's scalar representation. - - """ - return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. @@ -1901,7 +1875,7 @@ class Immutable(object): return self -class CompareMixin(ColumnOperators): +class _DefaultColumnComparator(ColumnOperators): """Defines comparison and math operations. The :class:`.CompareMixin` is part of the interface provided @@ -1913,91 +1887,74 @@ class CompareMixin(ColumnOperators): """ - def __compare(self, op, obj, negate=None, reverse=False, + def __compare(self, expr, op, obj, negate=None, reverse=False, **kwargs ): if obj is None or isinstance(obj, Null): if op == operators.eq: - return BinaryExpression(self, null(), operators.is_, + return BinaryExpression(expr, null(), operators.is_, negate=operators.isnot) elif op == operators.ne: - return BinaryExpression(self, null(), operators.isnot, + return BinaryExpression(expr, null(), operators.isnot, negate=operators.is_) else: raise exc.ArgumentError("Only '='/'!=' operators can " "be used with NULL") else: - obj = self._check_literal(op, obj) + obj = self._check_literal(expr, op, obj) if reverse: return BinaryExpression(obj, - self, + expr, op, type_=sqltypes.BOOLEANTYPE, negate=negate, modifiers=kwargs) else: - return BinaryExpression(self, + return BinaryExpression(expr, obj, op, type_=sqltypes.BOOLEANTYPE, negate=negate, modifiers=kwargs) - def __operate(self, op, obj, reverse=False): - obj = self._check_literal(op, obj) - + def __operate(self, expr, op, obj, reverse=False): + obj = self._check_literal(expr, op, obj) + comparator_factory = None if reverse: - left, right = obj, self + left, right = obj, expr else: - left, right = self, obj + left, right = expr, obj if left.type is None: op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) + result_type = sqltypes.to_instance(result_type) + if right.type._compare_type_affinity(result_type): + comparator_factory = right.comparator_factory elif right.type is None: op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) + result_type = sqltypes.to_instance(result_type) + if left.type._compare_type_affinity(result_type): + comparator_factory = left.comparator_factory else: - op, result_type = left.type._adapt_expression(op, - right.type) - return BinaryExpression(left, right, op, type_=result_type) + op, result_type = left.type._adapt_expression(op, right.type) + result_type = sqltypes.to_instance(result_type) + if left.type._compare_type_affinity(result_type): + comparator_factory = left.comparator_factory + elif right.type._compare_type_affinity(result_type): + comparator_factory = right.comparator_factory - # a mapping of operators with the method they use, along with their negated - # operator for comparison operators - operators = { - "add": (__operate,), - "mul": (__operate,), - "sub": (__operate,), - "div": (__operate,), - "mod": (__operate,), - "truediv": (__operate,), - "custom_op": (__operate,), - "lt": (__compare, operators.ge), - "le": (__compare, operators.gt), - "ne": (__compare, operators.eq), - "gt": (__compare, operators.le), - "ge": (__compare, operators.lt), - "eq": (__compare, operators.ne), - "like_op": (__compare, operators.notlike_op), - "ilike_op": (__compare, operators.notilike_op), - } + return BinaryExpression(left, right, op, type_=result_type, + comparator_factory=comparator_factory) - def operate(self, op, *other, **kwargs): - o = CompareMixin.operators[op.__name__] - return o[0](self, op, other[0], *o[1:], **kwargs) + def __scalar(self, expr, op, fn, **kw): + return fn(expr) - def reverse_operate(self, op, other, **kwargs): - o = CompareMixin.operators[op.__name__] - return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - - def in_(self, other): - """See :meth:`.ColumnOperators.in_`.""" - return self._in_impl(operators.in_op, operators.notin_op, other) - - def _in_impl(self, op, negate_op, seq_or_selectable): + def _in_impl(self, expr, op, seq_or_selectable, negate_op): seq_or_selectable = _clause_element_as_expr(seq_or_selectable) if isinstance(seq_or_selectable, ScalarSelect): - return self.__compare(op, seq_or_selectable, + return self.__compare(expr, op, seq_or_selectable, negate=negate_op) elif isinstance(seq_or_selectable, SelectBase): @@ -2006,10 +1963,10 @@ class CompareMixin(ColumnOperators): # as_scalar() to produce a multi- column selectable that # does not export itself as a FROM clause - return self.__compare(op, seq_or_selectable.as_scalar(), + return self.__compare(expr, op, seq_or_selectable.as_scalar(), negate=negate_op) elif isinstance(seq_or_selectable, (Selectable, TextClause)): - return self.__compare(op, seq_or_selectable, + return self.__compare(expr, op, seq_or_selectable, negate=negate_op) @@ -2018,12 +1975,12 @@ class CompareMixin(ColumnOperators): args = [] for o in seq_or_selectable: if not _is_literal(o): - if not isinstance(o, CompareMixin): + if not isinstance(o, ColumnOperators): raise exc.InvalidRequestError('in() function accept' 's either a list of non-selectable values, ' 'or a selectable: %r' % o) else: - o = self._bind_param(op, o) + o = expr._bind_param(op, o) args.append(o) if len(args) == 0: @@ -2037,18 +1994,17 @@ class CompareMixin(ColumnOperators): 'empty sequence. This results in a ' 'contradiction, which nonetheless can be ' 'expensive to evaluate. Consider alternative ' - 'strategies for improved performance.' % self) - return self != self + 'strategies for improved performance.' % expr) + return expr != expr - return self.__compare(op, + return self.__compare(expr, op, ClauseList(*args).self_group(against=op), negate=negate_op) - - def __neg__(self): + def _neg_impl(self): """See :meth:`.ColumnOperators.__neg__`.""" - return UnaryExpression(self, operator=operators.neg) + return UnaryExpression(self.expr, operator=operators.neg) - def startswith(self, other, escape=None): + def _startswith_impl(self, other, escape=None): """See :meth:`.ColumnOperators.startswith`.""" # use __radd__ to force string concat behavior return self.__compare( @@ -2058,7 +2014,7 @@ class CompareMixin(ColumnOperators): ), escape=escape) - def endswith(self, other, escape=None): + def _endswith_impl(self, other, escape=None): """See :meth:`.ColumnOperators.endswith`.""" return self.__compare( operators.like_op, @@ -2066,7 +2022,7 @@ class CompareMixin(ColumnOperators): self._check_literal(operators.like_op, other), escape=escape) - def contains(self, other, escape=None): + def _contains_impl(self, other, escape=None): """See :meth:`.ColumnOperators.contains`.""" return self.__compare( operators.like_op, @@ -2075,44 +2031,18 @@ class CompareMixin(ColumnOperators): literal_column("'%'", type_=sqltypes.String), escape=escape) - def match(self, other): + def _match_impl(self, other): """See :meth:`.ColumnOperators.match`.""" return self.__compare(operators.match_op, self._check_literal(operators.match_op, other)) - def label(self, name): - """Produce a column label, i.e. `` AS ``. - - This is a shortcut to the :func:`~.expression.label` function. - - if 'name' is None, an anonymous label name will be generated. - - """ - return Label(name, self, self.type) - - def desc(self): - """See :meth:`.ColumnOperators.desc`.""" - return desc(self) - - def asc(self): - """See :meth:`.ColumnOperators.asc`.""" - return asc(self) - - def nullsfirst(self): - """See :meth:`.ColumnOperators.nullsfirst`.""" - return nullsfirst(self) - - def nullslast(self): - """See :meth:`.ColumnOperators.nullslast`.""" - return nullslast(self) - - def distinct(self): + def _distinct_impl(self): """See :meth:`.ColumnOperators.distinct`.""" return UnaryExpression(self, operator=operators.distinct_op, type_=self.type) - def between(self, cleft, cright): + def _between_impl(self, cleft, cright): """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( self, @@ -2123,23 +2053,52 @@ class CompareMixin(ColumnOperators): group=False), operators.between_op) - def collate(self, collation): - """See :meth:`.ColumnOperators.collate`.""" + def _collate_impl(self, expr, op, other): + return collate(expr, other) + + # a mapping of operators with the method they use, along with their negated + # operator for comparison operators + operators = { + "add": (__operate,), + "mul": (__operate,), + "sub": (__operate,), + "div": (__operate,), + "mod": (__operate,), + "truediv": (__operate,), + "custom_op": (__operate,), + "lt": (__compare, operators.ge), + "le": (__compare, operators.gt), + "ne": (__compare, operators.eq), + "gt": (__compare, operators.le), + "ge": (__compare, operators.lt), + "eq": (__compare, operators.ne), + "like_op": (__compare, operators.notlike_op), + "ilike_op": (__compare, operators.notilike_op), + "desc_op": (__scalar, desc), + "asc_op": (__scalar, asc), + "nullsfirst_op": (__scalar, nullsfirst), + "nullslast_op": (__scalar, nullslast), + "in_op": (_in_impl, operators.notin_op), + "collate": (_collate_impl,), + } + + def operate(self, expr, op, *other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, expr, op, *(other + o[1:]), **kwargs) + + def reverse_operate(self, expr, op, other, **kwargs): + o = self.operators[op.__name__] + return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs) - return collate(self, collation) - def _bind_param(self, operator, obj): - return BindParameter(None, obj, - _compared_to_operator=operator, - _compared_to_type=self.type, unique=True) - def _check_literal(self, operator, other): + def _check_literal(self, expr, operator, other): if isinstance(other, BindParameter) and \ isinstance(other.type, sqltypes.NullType): # TODO: perhaps we should not mutate the incoming bindparam() # here and instead make a copy of it. this might # be the only place that we're mutating an incoming construct. - other.type = self.type + other.type = expr.type return other elif hasattr(other, '__clause_element__'): other = other.__clause_element__() @@ -2147,14 +2106,16 @@ class CompareMixin(ColumnOperators): other = other.as_scalar() return other elif not isinstance(other, ClauseElement): - return self._bind_param(operator, other) + return expr._bind_param(operator, other) elif isinstance(other, (SelectBase, Alias)): return other.as_scalar() else: return other +_DEFAULT_COMPARATOR = _DefaultColumnComparator() -class ColumnElement(ClauseElement, CompareMixin): + +class ColumnElement(ClauseElement, ColumnOperators): """Represent an element that is usable within the "column clause" portion of a ``SELECT`` statement. @@ -2203,6 +2164,32 @@ class ColumnElement(ClauseElement, CompareMixin): _key_label = None _alt_names = () + comparator = None + + class Comparator(operators.ColumnOperators): + def __init__(self, expr): + self.expr = expr + + def operate(self, op, *other, **kwargs): + return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other, + **kwargs) + + def __getattr__(self, key): + if self.comparator is None: + raise AttributeError(key) + try: + return getattr(self.comparator, key) + except AttributeError: + raise AttributeError( + 'Neither %r object nor %r object has an attribute %r' % ( + type(self).__name__, + type(self.comparator).__name__, + key) + ) + @property def expression(self): """Return a column expression. @@ -2212,6 +2199,23 @@ class ColumnElement(ClauseElement, CompareMixin): """ return self + def operate(self, op, *other, **kwargs): + if self.comparator: + return op(self.comparator, *other, **kwargs) + else: + return _DEFAULT_COMPARATOR.operate(self, op, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + if self.comparator: + return op(other, self.comparator, **kwargs) + else: + return _DEFAULT_COMPARATOR.reverse_operate(self, op, *other, **kwargs) + + def _bind_param(self, operator, obj): + return BindParameter(None, obj, + _compared_to_operator=operator, + _compared_to_type=self.type, unique=True) + @property def _select_iterable(self): return (self, ) @@ -2292,6 +2296,16 @@ class ColumnElement(ClauseElement, CompareMixin): else: return False + def label(self, name): + """Produce a column label, i.e. `` AS ``. + + This is a shortcut to the :func:`~.expression.label` function. + + if 'name' is None, an anonymous label name will be generated. + + """ + return Label(name, self, self.type) + @util.memoized_property def anon_label(self): """provides a constant 'anonymous label' for this ColumnElement. @@ -3544,7 +3558,7 @@ class BinaryExpression(ColumnElement): __visit_name__ = 'binary' def __init__(self, left, right, operator, type_=None, - negate=None, modifiers=None): + negate=None, modifiers=None, comparator_factory=None): # allow compatibility with libraries that # refer to BinaryExpression directly and pass strings if isinstance(operator, basestring): @@ -3554,6 +3568,11 @@ class BinaryExpression(ColumnElement): self.operator = operator self.type = sqltypes.to_instance(type_) self.negate = negate + + self.comparator_factory = comparator_factory + if comparator_factory is not None: + self.comparator = comparator_factory(self) + if modifiers is None: self.modifiers = {} else: @@ -3959,12 +3978,16 @@ class Grouping(ColumnElement): return getattr(self.element, attr) def __getstate__(self): - return {'element':self.element, 'type':self.type} + return {'element': self.element, 'type': self.type} def __setstate__(self, state): self.element = state['element'] self.type = state['type'] + def compare(self, other, **kw): + return isinstance(other, Grouping) and \ + self.element.compare(other.element) + class FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' @@ -4186,6 +4209,12 @@ class ColumnClause(Immutable, ColumnElement): :func:`literal_column()` function is usually used to create such a :class:`.ColumnClause`. + :param comparator_factory: a :class:`.operators.ColumnOperators` subclass + which will produce custom operator behavior. + + .. versionadded: 0.8 support for pluggable operators in + core column expressions. + """ __visit_name__ = 'column' @@ -4193,11 +4222,15 @@ class ColumnClause(Immutable, ColumnElement): _memoized_property = util.group_expirable_memoized_property() - def __init__(self, text, selectable=None, type_=None, is_literal=False): + def __init__(self, text, selectable=None, type_=None, is_literal=False, + comparator_factory=None): self.key = self.name = text self.table = selectable self.type = sqltypes.to_instance(type_) self.is_literal = is_literal + self.comparator_factory = comparator_factory + if comparator_factory: + self.comparator = comparator_factory(self) def _compare_name_for_result(self, other): if self.table is not None and hasattr(other, 'proxy_set'): @@ -4283,7 +4316,8 @@ class ColumnClause(Immutable, ColumnElement): _as_truncated(name if name else self.name), selectable=selectable, type_=self.type, - is_literal=is_literal + is_literal=is_literal, + comparator_factory=self.comparator_factory ) c.proxies = [self] if selectable._is_clone_of is not None: @@ -5916,7 +5950,6 @@ class ReleaseSavepointClause(_IdentifiedClause): # old names for compatibility _BindParamClause = BindParameter _Label = Label -_CompareMixin = CompareMixin _SelectBase = SelectBase _BinaryExpression = BinaryExpression _Cast = Cast diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 89681fa6a6..f851a5b003 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -176,6 +176,10 @@ class custom_op(object): self.opstring = opstring self.precedence = precedence + def __eq__(self, other): + return isinstance(other, custom_op) and \ + other.opstring == self.opstring + def __call__(self, left, right, **kw): return left.operate(self, right, **kw) diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 658fc77e91..a79bf03290 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -844,7 +844,7 @@ class _DateAffinity(object): othertype = othertype._type_affinity return op, \ self._expression_adaptations.get(op, self._blank_dict).\ - get(othertype, NULLTYPE) + get(othertype, self) class String(Concatenable, TypeEngine): """The base for all string and character types. diff --git a/test/lib/profiling.py b/test/lib/profiling.py index 6142c41c9e..6ca28d4620 100644 --- a/test/lib/profiling.py +++ b/test/lib/profiling.py @@ -95,7 +95,7 @@ class ProfileStatsFile(object): """ def __init__(self): from test.bootstrap.config import options - self.write = options.write_profiles + self.write = options is not None and options.write_profiles dirname, fname = os.path.split(__file__) self.short_fname = "profiles.txt" self.fname = os.path.join(dirname, self.short_fname) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py new file mode 100644 index 0000000000..6e1966a587 --- /dev/null +++ b/test/sql/test_operators.py @@ -0,0 +1,105 @@ +from test.lib import fixtures +from sqlalchemy.sql import column, desc, asc, literal, collate +from sqlalchemy.sql.expression import BinaryExpression, \ + ClauseList, Grouping, _DefaultColumnComparator +from sqlalchemy.sql import operators +from sqlalchemy.schema import Column, Table, MetaData +from sqlalchemy.types import Integer + +class DefaultColumnComparatorTest(fixtures.TestBase): + + def _do_scalar_test(self, operator, compare_to): + cc = _DefaultColumnComparator() + left = column('left') + assert cc.operate(left, operator).compare( + compare_to(left) + ) + + def _do_operate_test(self, operator): + cc = _DefaultColumnComparator() + left = column('left') + right = column('right') + + assert cc.operate(left, operator, right).compare( + BinaryExpression(left, right, operator) + ) + + def test_desc(self): + self._do_scalar_test(operators.desc_op, desc) + + def test_asc(self): + self._do_scalar_test(operators.asc_op, asc) + + def test_plus(self): + self._do_operate_test(operators.add) + + def test_in(self): + cc = _DefaultColumnComparator() + left = column('left') + assert cc.operate(left, operators.in_op, [1, 2, 3]).compare( + BinaryExpression( + left, + Grouping(ClauseList( + literal(1), literal(2), literal(3) + )), + operators.in_op + ) + ) + + def test_collate(self): + cc = _DefaultColumnComparator() + left = column('left') + right = "some collation" + cc.operate(left, operators.collate, right).compare( + collate(left, right) + ) + +class CustomComparatorTest(fixtures.TestBase): + def _add_override_factory(self): + class MyComparator(Column.Comparator): + def __init__(self, expr): + self.expr = expr + + def __add__(self, other): + return self.expr.op("goofy")(other) + return MyComparator + + def _assert_add_override(self, expr): + assert (expr + 5).compare( + expr.op("goofy")(5) + ) + + def _assert_not_add_override(self, expr): + assert not (expr + 5).compare( + expr.op("goofy")(5) + ) + + def test_override_builtin(self): + c1 = Column('foo', Integer, + comparator_factory=self._add_override_factory()) + self._assert_add_override(c1) + + def test_column_proxy(self): + t = Table('t', MetaData(), + Column('foo', Integer, + comparator_factory=self._add_override_factory())) + proxied = t.select().c.foo + self._assert_add_override(proxied) + + def test_binary_propagate(self): + c1 = Column('foo', Integer, + comparator_factory=self._add_override_factory()) + + self._assert_add_override(c1 - 6) + + def test_binary_multi_propagate(self): + c1 = Column('foo', Integer, + comparator_factory=self._add_override_factory()) + self._assert_add_override((c1 - 6) + 5) + + def test_no_boolean_propagate(self): + c1 = Column('foo', Integer, + comparator_factory=self._add_override_factory()) + + self._assert_not_add_override(c1 == 56) +