From 37fad88b84db61fba0a09a1c76bcf95d055aa6e2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 13 Aug 2012 16:18:12 -0400 Subject: [PATCH] move the whole thing to TypeEngine. the feature is pretty much for free like this. --- lib/sqlalchemy/schema.py | 12 +-- lib/sqlalchemy/sql/expression.py | 54 +++-------- lib/sqlalchemy/types.py | 93 +++++++++++------- test/sql/test_operators.py | 161 +++++++++++++++++++++++++------ 4 files changed, 200 insertions(+), 120 deletions(-) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 683a93d98b..e79e12f7b8 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -765,12 +765,6 @@ class Column(SchemaItem, expression.ColumnClause): setups, such as the one demonstrated in the ORM documentation at :ref:`post_update`. - :param comparator_factory: a :class:`.operators.ColumnOperators` subclass - which will produce custom operator behavior. - - .. versionadded: 0.8 support for pluggable operators in - core column expressions. - :param default: A scalar, Python callable, or :class:`~sqlalchemy.sql.expression.ClauseElement` representing the *default value* for this column, which will be invoked upon insert @@ -891,9 +885,7 @@ class Column(SchemaItem, expression.ColumnClause): no_type = type_ is None - super(Column, self).__init__(name, None, type_, - comparator_factory= - kwargs.pop('comparator_factory', None)) + super(Column, self).__init__(name, None, type_) self.key = kwargs.pop('key', name) self.primary_key = kwargs.pop('primary_key', False) self.nullable = kwargs.pop('nullable', not self.primary_key) @@ -1082,7 +1074,6 @@ class Column(SchemaItem, expression.ColumnClause): name=self.name, type_=self.type, key = self.key, - comparator_factory = self.comparator_factory, primary_key = self.primary_key, nullable = self.nullable, unique = self.unique, @@ -1121,7 +1112,6 @@ class Column(SchemaItem, expression.ColumnClause): key = key if key else name if name else self.key, primary_key = self.primary_key, nullable = self.nullable, - comparator_factory = self.comparator_factory, quote=self.quote, _proxies=[self], *fk) except TypeError, e: # Py3K diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 84d7c1a299..b92ec45292 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -1918,7 +1918,7 @@ class _DefaultColumnComparator(ColumnOperators): def __operate(self, expr, op, obj, reverse=False): obj = self._check_literal(expr, op, obj) - comparator_factory = None + if reverse: left, right = obj, expr else: @@ -1927,25 +1927,13 @@ class _DefaultColumnComparator(ColumnOperators): if left.type is None: op, result_type = sqltypes.NULLTYPE._adapt_expression(op, right.type) - result_type = sqltypes.to_instance(result_type) - if right.type._compare_type_affinity(result_type): - comparator_factory = right.comparator_factory elif right.type is None: op, result_type = left.type._adapt_expression(op, sqltypes.NULLTYPE) - result_type = sqltypes.to_instance(result_type) - if left.type._compare_type_affinity(result_type): - comparator_factory = left.comparator_factory else: op, result_type = left.type._adapt_expression(op, right.type) - result_type = sqltypes.to_instance(result_type) - if left.type._compare_type_affinity(result_type): - comparator_factory = left.comparator_factory - elif right.type._compare_type_affinity(result_type): - comparator_factory = right.comparator_factory - return BinaryExpression(left, right, op, type_=result_type, - comparator_factory=comparator_factory) + return BinaryExpression(left, right, op, type_=result_type) def __scalar(self, expr, op, fn, **kw): return fn(expr) @@ -2159,23 +2147,20 @@ class ColumnElement(ClauseElement, ColumnOperators): __visit_name__ = 'column' primary_key = False foreign_keys = [] + type = None quote = None _label = None _key_label = None _alt_names = () - comparator = None - - class Comparator(operators.ColumnOperators): - def __init__(self, expr): - self.expr = expr - - def operate(self, op, *other, **kwargs): - return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs) - - def reverse_operate(self, op, other, **kwargs): - return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other, - **kwargs) + @util.memoized_property + def comparator(self): + if self.type is None: + return None + elif self.type.comparator_factory is not None: + return self.type.comparator_factory(self) + else: + return None def __getattr__(self, key): if self.comparator is None: @@ -3558,7 +3543,7 @@ class BinaryExpression(ColumnElement): __visit_name__ = 'binary' def __init__(self, left, right, operator, type_=None, - negate=None, modifiers=None, comparator_factory=None): + negate=None, modifiers=None): # allow compatibility with libraries that # refer to BinaryExpression directly and pass strings if isinstance(operator, basestring): @@ -3569,10 +3554,6 @@ class BinaryExpression(ColumnElement): self.type = sqltypes.to_instance(type_) self.negate = negate - self.comparator_factory = comparator_factory - if comparator_factory is not None: - self.comparator = comparator_factory(self) - if modifiers is None: self.modifiers = {} else: @@ -4209,11 +4190,6 @@ class ColumnClause(Immutable, ColumnElement): :func:`literal_column()` function is usually used to create such a :class:`.ColumnClause`. - :param comparator_factory: a :class:`.operators.ColumnOperators` subclass - which will produce custom operator behavior. - - .. versionadded: 0.8 support for pluggable operators in - core column expressions. """ __visit_name__ = 'column' @@ -4222,15 +4198,11 @@ class ColumnClause(Immutable, ColumnElement): _memoized_property = util.group_expirable_memoized_property() - def __init__(self, text, selectable=None, type_=None, is_literal=False, - comparator_factory=None): + def __init__(self, text, selectable=None, type_=None, is_literal=False): self.key = self.name = text self.table = selectable self.type = sqltypes.to_instance(type_) self.is_literal = is_literal - self.comparator_factory = comparator_factory - if comparator_factory: - self.comparator = comparator_factory(self) def _compare_name_for_result(self, other): if self.table is not None and hasattr(other, 'proxy_set'): diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index a79bf03290..b6fdb3261c 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -25,6 +25,7 @@ import codecs from . import exc, schema, util, processors, events, event from .sql import operators +from .sql.expression import _DEFAULT_COMPARATOR from .util import pickle from .util.compat import decimal from .sql.visitors import Visitable @@ -41,6 +42,23 @@ class AbstractType(Visitable): class TypeEngine(AbstractType): """Base for built-in types.""" + class Comparator(operators.ColumnOperators): + def __init__(self, expr): + self.expr = expr + + def operate(self, op, *other, **kwargs): + return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs) + + def reverse_operate(self, op, other, **kwargs): + return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other, + **kwargs) + + comparator_factory = None + """A :class:`.TypeEngine.Comparator` class which will apply + to operations performed by owning :class:`.ColumnElement` objects. + + """ + def copy_value(self, value): return value @@ -451,6 +469,9 @@ class TypeDecorator(TypeEngine): "type being decorated") self.impl = to_instance(self.__class__.impl, *args, **kwargs) + @property + def comparator_factory(self): + return self.impl.comparator_factory def _gen_dialect_impl(self, dialect): """ @@ -700,11 +721,9 @@ class TypeDecorator(TypeEngine): return self.impl.compare_values(x, y) def _adapt_expression(self, op, othertype): - """ - #todo - """ - op, typ =self.impl._adapt_expression(op, othertype) - if typ is self.impl: + op, typ = self.impl._adapt_expression(op, othertype) + typ = to_instance(typ) + if typ._compare_type_affinity(self.impl): return op, self else: return op, typ @@ -844,7 +863,7 @@ class _DateAffinity(object): othertype = othertype._type_affinity return op, \ self._expression_adaptations.get(op, self._blank_dict).\ - get(othertype, self) + get(othertype, NULLTYPE) class String(Concatenable, TypeEngine): """The base for all string and character types. @@ -1136,26 +1155,26 @@ class Integer(_DateAffinity, TypeEngine): return { operators.add:{ Date:Date, - Integer:Integer, + Integer:self.__class__, Numeric:Numeric, }, operators.mul:{ Interval:Interval, - Integer:Integer, + Integer:self.__class__, Numeric:Numeric, }, # Py2K operators.div:{ - Integer:Integer, + Integer:self.__class__, Numeric:Numeric, }, # end Py2K operators.truediv:{ - Integer:Integer, + Integer:self.__class__, Numeric:Numeric, }, operators.sub:{ - Integer:Integer, + Integer:self.__class__, Numeric:Numeric, }, } @@ -1311,26 +1330,26 @@ class Numeric(_DateAffinity, TypeEngine): return { operators.mul:{ Interval:Interval, - Numeric:Numeric, - Integer:Numeric, + Numeric:self.__class__, + Integer:self.__class__, }, # Py2K operators.div:{ - Numeric:Numeric, - Integer:Numeric, + Numeric:self.__class__, + Integer:self.__class__, }, # end Py2K operators.truediv:{ - Numeric:Numeric, - Integer:Numeric, + Numeric:self.__class__, + Integer:self.__class__, }, operators.add:{ - Numeric:Numeric, - Integer:Numeric, + Numeric:self.__class__, + Integer:self.__class__, }, operators.sub:{ - Numeric:Numeric, - Integer:Numeric, + Numeric:self.__class__, + Integer:self.__class__, } } @@ -1380,21 +1399,21 @@ class Float(Numeric): return { operators.mul:{ Interval:Interval, - Numeric:Float, + Numeric:self.__class__, }, # Py2K operators.div:{ - Numeric:Float, + Numeric:self.__class__, }, # end Py2K operators.truediv:{ - Numeric:Float, + Numeric:self.__class__, }, operators.add:{ - Numeric:Float, + Numeric:self.__class__, }, operators.sub:{ - Numeric:Float, + Numeric:self.__class__, } } @@ -1434,10 +1453,10 @@ class DateTime(_DateAffinity, TypeEngine): def _expression_adaptations(self): return { operators.add:{ - Interval:DateTime, + Interval:self.__class__, }, operators.sub:{ - Interval:DateTime, + Interval:self.__class__, DateTime:Interval, }, } @@ -1459,13 +1478,13 @@ class Date(_DateAffinity,TypeEngine): def _expression_adaptations(self): return { operators.add:{ - Integer:Date, + Integer:self.__class__, Interval:DateTime, Time:DateTime, }, operators.sub:{ # date - integer = date - Integer:Date, + Integer:self.__class__, # date - date = integer. Date:Integer, @@ -1500,11 +1519,11 @@ class Time(_DateAffinity,TypeEngine): return { operators.add:{ Date:DateTime, - Interval:Time + Interval:self.__class__ }, operators.sub:{ Time:Interval, - Interval:Time, + Interval:self.__class__, }, } @@ -2050,22 +2069,22 @@ class Interval(_DateAffinity, TypeDecorator): return { operators.add:{ Date:DateTime, - Interval:Interval, + Interval:self.__class__, DateTime:DateTime, Time:Time, }, operators.sub:{ - Interval:Interval + Interval:self.__class__ }, operators.mul:{ - Numeric:Interval + Numeric:self.__class__ }, operators.truediv: { - Numeric:Interval + Numeric:self.__class__ }, # Py2K operators.div: { - Numeric:Interval + Numeric:self.__class__ } # end Py2K } diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 6e1966a587..02acda0f11 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -4,7 +4,7 @@ from sqlalchemy.sql.expression import BinaryExpression, \ ClauseList, Grouping, _DefaultColumnComparator from sqlalchemy.sql import operators from sqlalchemy.schema import Column, Table, MetaData -from sqlalchemy.types import Integer +from sqlalchemy.types import Integer, TypeEngine, TypeDecorator class DefaultColumnComparatorTest(fixtures.TestBase): @@ -54,15 +54,44 @@ class DefaultColumnComparatorTest(fixtures.TestBase): collate(left, right) ) -class CustomComparatorTest(fixtures.TestBase): - def _add_override_factory(self): - class MyComparator(Column.Comparator): - def __init__(self, expr): - self.expr = expr +class _CustomComparatorTests(object): + def test_override_builtin(self): + c1 = Column('foo', self._add_override_factory()) + self._assert_add_override(c1) + + def test_column_proxy(self): + t = Table('t', MetaData(), + Column('foo', self._add_override_factory()) + ) + proxied = t.select().c.foo + self._assert_add_override(proxied) - def __add__(self, other): - return self.expr.op("goofy")(other) - return MyComparator + def test_alias_proxy(self): + t = Table('t', MetaData(), + Column('foo', self._add_override_factory()) + ) + proxied = t.alias().c.foo + self._assert_add_override(proxied) + + def test_binary_propagate(self): + c1 = Column('foo', self._add_override_factory()) + self._assert_add_override(c1 - 6) + + def test_reverse_binary_propagate(self): + c1 = Column('foo', self._add_override_factory()) + self._assert_add_override(6 - c1) + + def test_binary_multi_propagate(self): + c1 = Column('foo', self._add_override_factory(True)) + self._assert_add_override((c1 - 6) + 5) + + def test_no_binary_multi_propagate_wo_adapt(self): + c1 = Column('foo', self._add_override_factory()) + self._assert_not_add_override((c1 - 6) + 5) + + def test_no_boolean_propagate(self): + c1 = Column('foo', self._add_override_factory()) + self._assert_not_add_override(c1 == 56) def _assert_add_override(self, expr): assert (expr + 5).compare( @@ -74,32 +103,102 @@ class CustomComparatorTest(fixtures.TestBase): expr.op("goofy")(5) ) - def test_override_builtin(self): - c1 = Column('foo', Integer, - comparator_factory=self._add_override_factory()) - self._assert_add_override(c1) +class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase): + def _add_override_factory(self, include_adapt=False): - def test_column_proxy(self): - t = Table('t', MetaData(), - Column('foo', Integer, - comparator_factory=self._add_override_factory())) - proxied = t.select().c.foo - self._assert_add_override(proxied) + class MyInteger(Integer): + class comparator_factory(TypeEngine.Comparator): + def __init__(self, expr): + self.expr = expr - def test_binary_propagate(self): - c1 = Column('foo', Integer, - comparator_factory=self._add_override_factory()) + def __add__(self, other): + return self.expr.op("goofy")(other) - self._assert_add_override(c1 - 6) + if include_adapt: + def _adapt_expression(self, op, othertype): + if op.__name__ == 'custom_op': + return op, self + else: + return super(MyInteger, self)._adapt_expression( + op, othertype) - def test_binary_multi_propagate(self): - c1 = Column('foo', Integer, - comparator_factory=self._add_override_factory()) - self._assert_add_override((c1 - 6) + 5) + return MyInteger - def test_no_boolean_propagate(self): - c1 = Column('foo', Integer, - comparator_factory=self._add_override_factory()) - self._assert_not_add_override(c1 == 56) +class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase): + def _add_override_factory(self, include_adapt=False): + + class MyInteger(TypeDecorator): + impl = Integer + + class comparator_factory(TypeEngine.Comparator): + def __init__(self, expr): + self.expr = expr + + def __add__(self, other): + return self.expr.op("goofy")(other) + + if include_adapt: + def _adapt_expression(self, op, othertype): + if op.__name__ == 'custom_op': + return op, self + else: + return super(MyInteger, self)._adapt_expression( + op, othertype) + + return MyInteger + + +class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase): + def _add_override_factory(self, include_adapt=False): + class MyInteger(Integer): + class comparator_factory(TypeEngine.Comparator): + def __init__(self, expr): + self.expr = expr + + def __add__(self, other): + return self.expr.op("goofy")(other) + + if include_adapt: + def _adapt_expression(self, op, othertype): + if op.__name__ == 'custom_op': + return op, self + else: + return super(MyInteger, self)._adapt_expression( + op, othertype) + + class MyDecInteger(TypeDecorator): + impl = MyInteger + + return MyDecInteger + +class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase): + def _add_override_factory(self, include_adapt=False): + class MyInteger(Integer): + class comparator_factory(TypeEngine.Comparator): + def __init__(self, expr): + self.expr = expr + + def foob(self, other): + return self.expr.op("foob")(other) + + if include_adapt: + def _adapt_expression(self, op, othertype): + if op.__name__ == 'custom_op': + return op, self + else: + return super(MyInteger, self)._adapt_expression( + op, othertype) + + return MyInteger + + def _assert_add_override(self, expr): + assert (expr.foob(5)).compare( + expr.op("foob")(5) + ) + + def _assert_not_add_override(self, expr): + assert not hasattr(expr, "foob") + def test_no_binary_multi_propagate_wo_adapt(self): + pass \ No newline at end of file -- 2.47.3