]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
move the whole thing to TypeEngine. the feature is pretty much for free like this.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Aug 2012 20:18:12 +0000 (16:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 13 Aug 2012 20:18:12 +0000 (16:18 -0400)
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/sql/test_operators.py

index 683a93d98b675ca0e97a50bc13e341ce52840b60..e79e12f7b8c957f1cad6723cc8ce19d4efb3ff25 100644 (file)
@@ -765,12 +765,6 @@ class Column(SchemaItem, expression.ColumnClause):
               setups, such as the one demonstrated in the ORM documentation
               at :ref:`post_update`.
 
-        :param comparator_factory: a :class:`.operators.ColumnOperators` subclass
-         which will produce custom operator behavior.
-
-         .. versionadded: 0.8 support for pluggable operators in
-            core column expressions.
-
         :param default: A scalar, Python callable, or
             :class:`~sqlalchemy.sql.expression.ClauseElement` representing the
             *default value* for this column, which will be invoked upon insert
@@ -891,9 +885,7 @@ class Column(SchemaItem, expression.ColumnClause):
 
         no_type = type_ is None
 
-        super(Column, self).__init__(name, None, type_,
-                                    comparator_factory=
-                                        kwargs.pop('comparator_factory', None))
+        super(Column, self).__init__(name, None, type_)
         self.key = kwargs.pop('key', name)
         self.primary_key = kwargs.pop('primary_key', False)
         self.nullable = kwargs.pop('nullable', not self.primary_key)
@@ -1082,7 +1074,6 @@ class Column(SchemaItem, expression.ColumnClause):
                 name=self.name,
                 type_=self.type,
                 key = self.key,
-                comparator_factory = self.comparator_factory,
                 primary_key = self.primary_key,
                 nullable = self.nullable,
                 unique = self.unique,
@@ -1121,7 +1112,6 @@ class Column(SchemaItem, expression.ColumnClause):
                 key = key if key else name if name else self.key,
                 primary_key = self.primary_key,
                 nullable = self.nullable,
-                comparator_factory = self.comparator_factory,
                 quote=self.quote, _proxies=[self], *fk)
         except TypeError, e:
             # Py3K
index 84d7c1a2995d7c05ae7adb9b99c6c68bcfc31763..b92ec4529206c051bf2d08d1c64a3c04bc3602c7 100644 (file)
@@ -1918,7 +1918,7 @@ class _DefaultColumnComparator(ColumnOperators):
 
     def __operate(self, expr, op, obj, reverse=False):
         obj = self._check_literal(expr, op, obj)
-        comparator_factory = None
+
         if reverse:
             left, right = obj, expr
         else:
@@ -1927,25 +1927,13 @@ class _DefaultColumnComparator(ColumnOperators):
         if left.type is None:
             op, result_type = sqltypes.NULLTYPE._adapt_expression(op,
                     right.type)
-            result_type = sqltypes.to_instance(result_type)
-            if right.type._compare_type_affinity(result_type):
-                comparator_factory = right.comparator_factory
         elif right.type is None:
             op, result_type = left.type._adapt_expression(op,
                     sqltypes.NULLTYPE)
-            result_type = sqltypes.to_instance(result_type)
-            if left.type._compare_type_affinity(result_type):
-                comparator_factory = left.comparator_factory
         else:
             op, result_type = left.type._adapt_expression(op, right.type)
-            result_type = sqltypes.to_instance(result_type)
-            if left.type._compare_type_affinity(result_type):
-                comparator_factory = left.comparator_factory
-            elif right.type._compare_type_affinity(result_type):
-                comparator_factory = right.comparator_factory
 
-        return BinaryExpression(left, right, op, type_=result_type,
-                    comparator_factory=comparator_factory)
+        return BinaryExpression(left, right, op, type_=result_type)
 
     def __scalar(self, expr, op, fn, **kw):
         return fn(expr)
@@ -2159,23 +2147,20 @@ class ColumnElement(ClauseElement, ColumnOperators):
     __visit_name__ = 'column'
     primary_key = False
     foreign_keys = []
+    type = None
     quote = None
     _label = None
     _key_label = None
     _alt_names = ()
 
-    comparator = None
-
-    class Comparator(operators.ColumnOperators):
-        def __init__(self, expr):
-            self.expr = expr
-
-        def operate(self, op, *other, **kwargs):
-            return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
-
-        def reverse_operate(self, op, other, **kwargs):
-            return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other,
-                                                **kwargs)
+    @util.memoized_property
+    def comparator(self):
+        if self.type is None:
+            return None
+        elif self.type.comparator_factory is not None:
+            return self.type.comparator_factory(self)
+        else:
+            return None
 
     def __getattr__(self, key):
         if self.comparator is None:
@@ -3558,7 +3543,7 @@ class BinaryExpression(ColumnElement):
     __visit_name__ = 'binary'
 
     def __init__(self, left, right, operator, type_=None,
-                    negate=None, modifiers=None, comparator_factory=None):
+                    negate=None, modifiers=None):
         # allow compatibility with libraries that
         # refer to BinaryExpression directly and pass strings
         if isinstance(operator, basestring):
@@ -3569,10 +3554,6 @@ class BinaryExpression(ColumnElement):
         self.type = sqltypes.to_instance(type_)
         self.negate = negate
 
-        self.comparator_factory = comparator_factory
-        if comparator_factory is not None:
-            self.comparator = comparator_factory(self)
-
         if modifiers is None:
             self.modifiers = {}
         else:
@@ -4209,11 +4190,6 @@ class ColumnClause(Immutable, ColumnElement):
       :func:`literal_column()` function is usually used to create such a
       :class:`.ColumnClause`.
 
-    :param comparator_factory: a :class:`.operators.ColumnOperators` subclass
-     which will produce custom operator behavior.
-
-     .. versionadded: 0.8 support for pluggable operators in
-        core column expressions.
 
     """
     __visit_name__ = 'column'
@@ -4222,15 +4198,11 @@ class ColumnClause(Immutable, ColumnElement):
 
     _memoized_property = util.group_expirable_memoized_property()
 
-    def __init__(self, text, selectable=None, type_=None, is_literal=False,
-                            comparator_factory=None):
+    def __init__(self, text, selectable=None, type_=None, is_literal=False):
         self.key = self.name = text
         self.table = selectable
         self.type = sqltypes.to_instance(type_)
         self.is_literal = is_literal
-        self.comparator_factory = comparator_factory
-        if comparator_factory:
-            self.comparator = comparator_factory(self)
 
     def _compare_name_for_result(self, other):
         if self.table is not None and hasattr(other, 'proxy_set'):
index a79bf03290acdf5cf86967bd554afae95e806b9b..b6fdb3261cba81e43eb40b82144630523e9637e5 100644 (file)
@@ -25,6 +25,7 @@ import codecs
 
 from . import exc, schema, util, processors, events, event
 from .sql import operators
+from .sql.expression import _DEFAULT_COMPARATOR
 from .util import pickle
 from .util.compat import decimal
 from .sql.visitors import Visitable
@@ -41,6 +42,23 @@ class AbstractType(Visitable):
 class TypeEngine(AbstractType):
     """Base for built-in types."""
 
+    class Comparator(operators.ColumnOperators):
+        def __init__(self, expr):
+            self.expr = expr
+
+        def operate(self, op, *other, **kwargs):
+            return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
+
+        def reverse_operate(self, op, other, **kwargs):
+            return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other,
+                                                **kwargs)
+
+    comparator_factory = None
+    """A :class:`.TypeEngine.Comparator` class which will apply
+    to operations performed by owning :class:`.ColumnElement` objects.
+
+    """
+
     def copy_value(self, value):
         return value
 
@@ -451,6 +469,9 @@ class TypeDecorator(TypeEngine):
                                  "type being decorated")
         self.impl = to_instance(self.__class__.impl, *args, **kwargs)
 
+    @property
+    def comparator_factory(self):
+        return self.impl.comparator_factory
 
     def _gen_dialect_impl(self, dialect):
         """
@@ -700,11 +721,9 @@ class TypeDecorator(TypeEngine):
         return self.impl.compare_values(x, y)
 
     def _adapt_expression(self, op, othertype):
-        """
-        #todo
-        """
-        op, typ =self.impl._adapt_expression(op, othertype)
-        if typ is self.impl:
+        op, typ = self.impl._adapt_expression(op, othertype)
+        typ = to_instance(typ)
+        if typ._compare_type_affinity(self.impl):
             return op, self
         else:
             return op, typ
@@ -844,7 +863,7 @@ class _DateAffinity(object):
         othertype = othertype._type_affinity
         return op, \
                 self._expression_adaptations.get(op, self._blank_dict).\
-                get(othertype, self)
+                get(othertype, NULLTYPE)
 
 class String(Concatenable, TypeEngine):
     """The base for all string and character types.
@@ -1136,26 +1155,26 @@ class Integer(_DateAffinity, TypeEngine):
         return {
             operators.add:{
                 Date:Date,
-                Integer:Integer,
+                Integer:self.__class__,
                 Numeric:Numeric,
             },
             operators.mul:{
                 Interval:Interval,
-                Integer:Integer,
+                Integer:self.__class__,
                 Numeric:Numeric,
             },
             # Py2K
             operators.div:{
-                Integer:Integer,
+                Integer:self.__class__,
                 Numeric:Numeric,
             },
             # end Py2K
             operators.truediv:{
-                Integer:Integer,
+                Integer:self.__class__,
                 Numeric:Numeric,
             },
             operators.sub:{
-                Integer:Integer,
+                Integer:self.__class__,
                 Numeric:Numeric,
             },
         }
@@ -1311,26 +1330,26 @@ class Numeric(_DateAffinity, TypeEngine):
         return {
             operators.mul:{
                 Interval:Interval,
-                Numeric:Numeric,
-                Integer:Numeric,
+                Numeric:self.__class__,
+                Integer:self.__class__,
             },
             # Py2K
             operators.div:{
-                Numeric:Numeric,
-                Integer:Numeric,
+                Numeric:self.__class__,
+                Integer:self.__class__,
             },
             # end Py2K
             operators.truediv:{
-                Numeric:Numeric,
-                Integer:Numeric,
+                Numeric:self.__class__,
+                Integer:self.__class__,
             },
             operators.add:{
-                Numeric:Numeric,
-                Integer:Numeric,
+                Numeric:self.__class__,
+                Integer:self.__class__,
             },
             operators.sub:{
-                Numeric:Numeric,
-                Integer:Numeric,
+                Numeric:self.__class__,
+                Integer:self.__class__,
             }
         }
 
@@ -1380,21 +1399,21 @@ class Float(Numeric):
         return {
             operators.mul:{
                 Interval:Interval,
-                Numeric:Float,
+                Numeric:self.__class__,
             },
             # Py2K
             operators.div:{
-                Numeric:Float,
+                Numeric:self.__class__,
             },
             # end Py2K
             operators.truediv:{
-                Numeric:Float,
+                Numeric:self.__class__,
             },
             operators.add:{
-                Numeric:Float,
+                Numeric:self.__class__,
             },
             operators.sub:{
-                Numeric:Float,
+                Numeric:self.__class__,
             }
         }
 
@@ -1434,10 +1453,10 @@ class DateTime(_DateAffinity, TypeEngine):
     def _expression_adaptations(self):
         return {
             operators.add:{
-                Interval:DateTime,
+                Interval:self.__class__,
             },
             operators.sub:{
-                Interval:DateTime,
+                Interval:self.__class__,
                 DateTime:Interval,
             },
         }
@@ -1459,13 +1478,13 @@ class Date(_DateAffinity,TypeEngine):
     def _expression_adaptations(self):
         return {
             operators.add:{
-                Integer:Date,
+                Integer:self.__class__,
                 Interval:DateTime,
                 Time:DateTime,
             },
             operators.sub:{
                 # date - integer = date
-                Integer:Date,
+                Integer:self.__class__,
 
                 # date - date = integer.
                 Date:Integer,
@@ -1500,11 +1519,11 @@ class Time(_DateAffinity,TypeEngine):
         return {
             operators.add:{
                 Date:DateTime,
-                Interval:Time
+                Interval:self.__class__
             },
             operators.sub:{
                 Time:Interval,
-                Interval:Time,
+                Interval:self.__class__,
             },
         }
 
@@ -2050,22 +2069,22 @@ class Interval(_DateAffinity, TypeDecorator):
         return {
             operators.add:{
                 Date:DateTime,
-                Interval:Interval,
+                Interval:self.__class__,
                 DateTime:DateTime,
                 Time:Time,
             },
             operators.sub:{
-                Interval:Interval
+                Interval:self.__class__
             },
             operators.mul:{
-                Numeric:Interval
+                Numeric:self.__class__
             },
             operators.truediv: {
-                Numeric:Interval
+                Numeric:self.__class__
             },
             # Py2K
             operators.div: {
-                Numeric:Interval
+                Numeric:self.__class__
             }
             # end Py2K
         }
index 6e1966a58776d27a0e531e3495d008be71f5bfa3..02acda0f11fe83c1a21a88d9dc7f4643072eb79d 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy.sql.expression import BinaryExpression, \
                 ClauseList, Grouping, _DefaultColumnComparator
 from sqlalchemy.sql import operators
 from sqlalchemy.schema import Column, Table, MetaData
-from sqlalchemy.types import Integer
+from sqlalchemy.types import Integer, TypeEngine, TypeDecorator
 
 class DefaultColumnComparatorTest(fixtures.TestBase):
 
@@ -54,15 +54,44 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
             collate(left, right)
         )
 
-class CustomComparatorTest(fixtures.TestBase):
-    def _add_override_factory(self):
-        class MyComparator(Column.Comparator):
-            def __init__(self, expr):
-                self.expr = expr
+class _CustomComparatorTests(object):
+    def test_override_builtin(self):
+        c1 = Column('foo', self._add_override_factory())
+        self._assert_add_override(c1)
+
+    def test_column_proxy(self):
+        t = Table('t', MetaData(),
+                Column('foo', self._add_override_factory())
+            )
+        proxied = t.select().c.foo
+        self._assert_add_override(proxied)
 
-            def __add__(self, other):
-                return self.expr.op("goofy")(other)
-        return MyComparator
+    def test_alias_proxy(self):
+        t = Table('t', MetaData(),
+                Column('foo', self._add_override_factory())
+            )
+        proxied = t.alias().c.foo
+        self._assert_add_override(proxied)
+
+    def test_binary_propagate(self):
+        c1 = Column('foo', self._add_override_factory())
+        self._assert_add_override(c1 - 6)
+
+    def test_reverse_binary_propagate(self):
+        c1 = Column('foo', self._add_override_factory())
+        self._assert_add_override(6 - c1)
+
+    def test_binary_multi_propagate(self):
+        c1 = Column('foo', self._add_override_factory(True))
+        self._assert_add_override((c1 - 6) + 5)
+
+    def test_no_binary_multi_propagate_wo_adapt(self):
+        c1 = Column('foo', self._add_override_factory())
+        self._assert_not_add_override((c1 - 6) + 5)
+
+    def test_no_boolean_propagate(self):
+        c1 = Column('foo', self._add_override_factory())
+        self._assert_not_add_override(c1 == 56)
 
     def _assert_add_override(self, expr):
         assert (expr + 5).compare(
@@ -74,32 +103,102 @@ class CustomComparatorTest(fixtures.TestBase):
             expr.op("goofy")(5)
         )
 
-    def test_override_builtin(self):
-        c1 = Column('foo', Integer,
-                comparator_factory=self._add_override_factory())
-        self._assert_add_override(c1)
+class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
+    def _add_override_factory(self, include_adapt=False):
 
-    def test_column_proxy(self):
-        t = Table('t', MetaData(),
-                Column('foo', Integer,
-                    comparator_factory=self._add_override_factory()))
-        proxied = t.select().c.foo
-        self._assert_add_override(proxied)
+        class MyInteger(Integer):
+            class comparator_factory(TypeEngine.Comparator):
+                def __init__(self, expr):
+                    self.expr = expr
 
-    def test_binary_propagate(self):
-        c1 = Column('foo', Integer,
-                comparator_factory=self._add_override_factory())
+                def __add__(self, other):
+                    return self.expr.op("goofy")(other)
 
-        self._assert_add_override(c1 - 6)
+            if include_adapt:
+                def _adapt_expression(self, op, othertype):
+                    if op.__name__ == 'custom_op':
+                        return op, self
+                    else:
+                        return super(MyInteger, self)._adapt_expression(
+                                                            op, othertype)
 
-    def test_binary_multi_propagate(self):
-        c1 = Column('foo', Integer,
-                comparator_factory=self._add_override_factory())
-        self._assert_add_override((c1 - 6) + 5)
+        return MyInteger
 
-    def test_no_boolean_propagate(self):
-        c1 = Column('foo', Integer,
-                comparator_factory=self._add_override_factory())
 
-        self._assert_not_add_override(c1 == 56)
+class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
+    def _add_override_factory(self, include_adapt=False):
+
+        class MyInteger(TypeDecorator):
+            impl = Integer
+
+            class comparator_factory(TypeEngine.Comparator):
+                def __init__(self, expr):
+                    self.expr = expr
+
+                def __add__(self, other):
+                    return self.expr.op("goofy")(other)
+
+            if include_adapt:
+                def _adapt_expression(self, op, othertype):
+                    if op.__name__ == 'custom_op':
+                        return op, self
+                    else:
+                        return super(MyInteger, self)._adapt_expression(
+                                                            op, othertype)
+
+        return MyInteger
+
+
+class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase):
+    def _add_override_factory(self, include_adapt=False):
+        class MyInteger(Integer):
+            class comparator_factory(TypeEngine.Comparator):
+                def __init__(self, expr):
+                    self.expr = expr
+
+                def __add__(self, other):
+                    return self.expr.op("goofy")(other)
+
+            if include_adapt:
+                def _adapt_expression(self, op, othertype):
+                    if op.__name__ == 'custom_op':
+                        return op, self
+                    else:
+                        return super(MyInteger, self)._adapt_expression(
+                                                            op, othertype)
+
+        class MyDecInteger(TypeDecorator):
+            impl = MyInteger
+
+        return MyDecInteger
+
+class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
+    def _add_override_factory(self, include_adapt=False):
+        class MyInteger(Integer):
+            class comparator_factory(TypeEngine.Comparator):
+                def __init__(self, expr):
+                    self.expr = expr
+
+                def foob(self, other):
+                    return self.expr.op("foob")(other)
+
+            if include_adapt:
+                def _adapt_expression(self, op, othertype):
+                    if op.__name__ == 'custom_op':
+                        return op, self
+                    else:
+                        return super(MyInteger, self)._adapt_expression(
+                                                            op, othertype)
+
+        return MyInteger
+
+    def _assert_add_override(self, expr):
+        assert (expr.foob(5)).compare(
+            expr.op("foob")(5)
+        )
+
+    def _assert_not_add_override(self, expr):
+        assert not hasattr(expr, "foob")
 
+    def test_no_binary_multi_propagate_wo_adapt(self):
+        pass
\ No newline at end of file