]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
_adapt_expression() moves fully to _DefaultColumnComparator which resumes
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2012 20:11:42 +0000 (16:11 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2012 20:11:42 +0000 (16:11 -0400)
its original role as stateful, forms the basis of TypeEngine.Comparator.  lots
of code goes back mostly as it was just with cleaner typing behavior, such
as simple flow in _binary_operate now.

lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/sql/test_operators.py
test/sql/test_types.py

index a9ff988e81802d6a53aecf8cd456b3cd7975f0a2..36da14d3337c1f29e8c1d32dc89776cb311a30f2 100644 (file)
@@ -708,9 +708,10 @@ class PGCompiler(compiler.SQLCompiler):
             affinity = None
 
         casts = {
-                    sqltypes.Date:'date',
-                    sqltypes.DateTime:'timestamp',
-                    sqltypes.Interval:'interval', sqltypes.Time:'time'
+                    sqltypes.Date: 'date',
+                    sqltypes.DateTime: 'timestamp',
+                    sqltypes.Interval: 'interval',
+                    sqltypes.Time: 'time'
                 }
         cast = casts.get(affinity, None)
         if isinstance(extract.expr, sql.ColumnElement) and cast is not None:
index 844293c73673c7eca0040d2302cd9377f821fc15..63fa23c15d4a13bdb51f7da779108d91c9f6c8ec 100644 (file)
@@ -1875,7 +1875,7 @@ class Immutable(object):
         return self
 
 
-class _DefaultColumnComparator(object):
+class _DefaultColumnComparator(operators.ColumnOperators):
     """Defines comparison and math operations.
 
     See :class:`.ColumnOperators` and :class:`.Operators` for descriptions
@@ -1883,6 +1883,45 @@ class _DefaultColumnComparator(object):
 
     """
 
+    @util.memoized_property
+    def type(self):
+        return self.expr.type
+
+    def operate(self, op, *other, **kwargs):
+        o = self.operators[op.__name__]
+        return o[0](self, self.expr, op, *(other + o[1:]), **kwargs)
+
+    def reverse_operate(self, op, other, **kwargs):
+        o = self.operators[op.__name__]
+        return o[0](self, self.expr, op, other, reverse=True, *o[1:], **kwargs)
+
+    def _adapt_expression(self, op, other_comparator):
+        """evaluate the return type of <self> <op> <othertype>,
+        and apply any adaptations to the given operator.
+
+        This method determines the type of a resulting binary expression
+        given two source types and an operator.   For example, two
+        :class:`.Column` objects, both of the type :class:`.Integer`, will
+        produce a :class:`.BinaryExpression` that also has the type
+        :class:`.Integer` when compared via the addition (``+``) operator.
+        However, using the addition operator with an :class:`.Integer`
+        and a :class:`.Date` object will produce a :class:`.Date`, assuming
+        "days delta" behavior by the database (in reality, most databases
+        other than Postgresql don't accept this particular operation).
+
+        The method returns a tuple of the form <operator>, <type>.
+        The resulting operator and type will be those applied to the
+        resulting :class:`.BinaryExpression` as the final operator and the
+        right-hand side of the expression.
+
+        Note that only a subset of operators make usage of
+        :meth:`._adapt_expression`,
+        including math operators and user-defined operators, but not
+        boolean comparison or special SQL keywords like MATCH or BETWEEN.
+
+        """
+        return op, other_comparator.type
+
     def _boolean_compare(self, expr, op, obj, negate=None, reverse=False,
                         **kwargs
         ):
@@ -1912,7 +1951,7 @@ class _DefaultColumnComparator(object):
                             type_=sqltypes.BOOLEANTYPE,
                             negate=negate, modifiers=kwargs)
 
-    def _binary_operate(self, expr, op, obj, result_type, reverse=False):
+    def _binary_operate(self, expr, op, obj, reverse=False):
         obj = self._check_literal(expr, op, obj)
 
         if reverse:
@@ -1920,6 +1959,8 @@ class _DefaultColumnComparator(object):
         else:
             left, right = expr, obj
 
+        op, result_type = left.comparator._adapt_expression(op, right.comparator)
+
         return BinaryExpression(left, right, op, type_=result_type)
 
     def _scalar(self, expr, op, fn, **kw):
@@ -1986,7 +2027,8 @@ class _DefaultColumnComparator(object):
             expr,
             operators.like_op,
             literal_column("'%'", type_=sqltypes.String).__radd__(
-                                self._check_literal(expr, operators.like_op, other)
+                                self._check_literal(expr,
+                                        operators.like_op, other)
                             ),
             escape=escape)
 
@@ -2068,21 +2110,16 @@ class _DefaultColumnComparator(object):
         "neg": (_neg_impl,),
     }
 
-    def operate(self, expr, op, *other, **kwargs):
-        o = self.operators[op.__name__]
-        return o[0](self, expr, op, *(other + o[1:]), **kwargs)
-
-    def reverse_operate(self, expr, op, other, **kwargs):
-        o = self.operators[op.__name__]
-        return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs)
 
     def _check_literal(self, expr, operator, other):
-        if isinstance(other, BindParameter) and \
-            isinstance(other.type, sqltypes.NullType):
-            # TODO: perhaps we should not mutate the incoming bindparam()
-            # here and instead make a copy of it.  this might
-            # be the only place that we're mutating an incoming construct.
-            other.type = expr.type
+        if isinstance(other, (ColumnElement, TextClause)):
+            if isinstance(other, BindParameter) and \
+                isinstance(other.type, sqltypes.NullType):
+                # TODO: perhaps we should not mutate the incoming
+                # bindparam() here and instead make a copy of it.
+                # this might be the only place that we're mutating
+                # an incoming construct.
+                other.type = expr.type
             return other
         elif hasattr(other, '__clause_element__'):
             other = other.__clause_element__()
@@ -2096,8 +2133,6 @@ class _DefaultColumnComparator(object):
         else:
             return other
 
-_DEFAULT_COMPARATOR = _DefaultColumnComparator()
-
 
 class ColumnElement(ClauseElement, ColumnOperators):
     """Represent an element that is usable within the "column clause" portion
@@ -2155,11 +2190,7 @@ class ColumnElement(ClauseElement, ColumnOperators):
     def comparator(self):
         return self.type.comparator_factory(self)
 
-    #def _assert_comparator(self):
-    #    assert self.comparator.expr is self
-
     def __getattr__(self, key):
-        #self._assert_comparator()
         try:
             return getattr(self.comparator, key)
         except AttributeError:
@@ -2171,11 +2202,9 @@ class ColumnElement(ClauseElement, ColumnOperators):
             )
 
     def operate(self, op, *other, **kwargs):
-        #self._assert_comparator()
         return op(self.comparator, *other, **kwargs)
 
     def reverse_operate(self, op, other, **kwargs):
-        #self._assert_comparator()
         return op(other, self.comparator, **kwargs)
 
     def _bind_param(self, operator, obj):
@@ -3090,6 +3119,10 @@ class TextClause(Executable, ClauseElement):
         else:
             return sqltypes.NULLTYPE
 
+    @property
+    def comparator(self):
+        return self.type.comparator_factory(self)
+
     def self_group(self, against=None):
         if against is operators.in_op:
             return Grouping(self)
index d4dbd648c2b6641bf4232ae64884cff9be4cff05..bbeebf5d3664ac43fde50c38bb346fb998f065a0 100644 (file)
@@ -11,21 +11,21 @@ types.
 For more information see the SQLAlchemy documentation on types.
 
 """
-__all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
-            'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR','TEXT', 'Text',
+__all__ = ['TypeEngine', 'TypeDecorator', 'AbstractType', 'UserDefinedType',
+            'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'NVARCHAR', 'TEXT', 'Text',
             'FLOAT', 'NUMERIC', 'REAL', 'DECIMAL', 'TIMESTAMP', 'DATETIME',
             'CLOB', 'BLOB', 'BINARY', 'VARBINARY', 'BOOLEAN', 'BIGINT', 'SMALLINT',
             'INTEGER', 'DATE', 'TIME', 'String', 'Integer', 'SmallInteger',
             'BigInteger', 'Numeric', 'Float', 'DateTime', 'Date', 'Time',
             'LargeBinary', 'Binary', 'Boolean', 'Unicode', 'Concatenable',
-            'UnicodeText','PickleType', 'Interval', 'Enum' ]
+            'UnicodeText', 'PickleType', 'Interval', 'Enum']
 
 import datetime as dt
 import codecs
 
 from . import exc, schema, util, processors, events, event
 from .sql import operators
-from .sql.expression import _DEFAULT_COMPARATOR
+from .sql.expression import _DefaultColumnComparator
 from .util import pickle
 from .util.compat import decimal
 from .sql.visitors import Visitable
@@ -42,7 +42,7 @@ class AbstractType(Visitable):
 class TypeEngine(AbstractType):
     """Base for built-in types."""
 
-    class Comparator(operators.ColumnOperators):
+    class Comparator(_DefaultColumnComparator):
         """Base class for custom comparison operations defined at the
         type level.  See :attr:`.TypeEngine.comparator_factory`.
 
@@ -54,24 +54,6 @@ class TypeEngine(AbstractType):
         def __reduce__(self):
             return _reconstitute_comparator, (self.expr, )
 
-        def operate(self, op, *other, **kwargs):
-            if len(other) == 1:
-                obj = other[0]
-                obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, obj)
-                op, adapt_type = self.expr.type._adapt_expression(op,
-                    obj.type)
-                kwargs['result_type'] = adapt_type
-
-            return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
-
-        def reverse_operate(self, op, other, **kwargs):
-
-            obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, other)
-            op, adapt_type = obj.type._adapt_expression(op, self.expr.type)
-            kwargs['result_type'] = adapt_type
-
-            return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, obj,
-                                                **kwargs)
 
     comparator_factory = Comparator
     """A :class:`.TypeEngine.Comparator` class which will apply
@@ -143,11 +125,6 @@ class TypeEngine(AbstractType):
         >>> (c1 == c2).type
         Boolean()
 
-    The propagation of :class:`.TypeEngine.Comparator` throughout an expression
-    will follow with how the :class:`.TypeEngine` itself is propagated.  To
-    customize the behavior of most operators in this regard, see the
-    :meth:`._adapt_expression` method.
-
     .. versionadded:: 0.8  The expression system was reworked to support
       user-defined comparator objects specified at the type level.
 
@@ -247,34 +224,7 @@ class TypeEngine(AbstractType):
         .. versionadded:: 0.7.2
 
         """
-        return Variant(self, {dialect_name:type_})
-
-    def _adapt_expression(self, op, othertype):
-        """evaluate the return type of <self> <op> <othertype>,
-        and apply any adaptations to the given operator.
-
-        This method determines the type of a resulting binary expression
-        given two source types and an operator.   For example, two
-        :class:`.Column` objects, both of the type :class:`.Integer`, will
-        produce a :class:`.BinaryExpression` that also has the type
-        :class:`.Integer` when compared via the addition (``+``) operator.
-        However, using the addition operator with an :class:`.Integer`
-        and a :class:`.Date` object will produce a :class:`.Date`, assuming
-        "days delta" behavior by the database (in reality, most databases
-        other than Postgresql don't accept this particular operation).
-
-        The method returns a tuple of the form <operator>, <type>.
-        The resulting operator and type will be those applied to the
-        resulting :class:`.BinaryExpression` as the final operator and the
-        right-hand side of the expression.
-
-        Note that only a subset of operators make usage of
-        :meth:`._adapt_expression`,
-        including math operators and user-defined operators, but not
-        boolean comparison or special SQL keywords like MATCH or BETWEEN.
-
-        """
-        return op, self
+        return Variant(self, {dialect_name: type_})
 
     @util.memoized_property
     def _type_affinity(self):
@@ -334,7 +284,7 @@ class TypeEngine(AbstractType):
                 impl = self.adapt(type(self))
             # this can't be self, else we create a cycle
             assert impl is not self
-            dialect._type_memos[self] = d = {'impl':impl}
+            dialect._type_memos[self] = d = {'impl': impl}
             return d
 
     def _gen_dialect_impl(self, dialect):
@@ -461,22 +411,21 @@ class UserDefinedType(TypeEngine):
     """
     __visit_name__ = "user_defined"
 
-    def _adapt_expression(self, op, othertype):
-        """evaluate the return type of <self> <op> <othertype>,
-        and apply any adaptations to the given operator.
-
-        """
-        return self.adapt_operator(op), self
-
-    def adapt_operator(self, op):
-        """A hook which allows the given operator to be adapted
-        to something new.
+    class Comparator(TypeEngine.Comparator):
+        def _adapt_expression(self, op, other_comparator):
+            if hasattr(self.type, 'adapt_operator'):
+                util.warn_deprecated(
+                    "UserDefinedType.adapt_operator is deprecated.  Create "
+                     "a UserDefinedType.Comparator subclass instead which "
+                     "generates the desired expression constructs, given a "
+                     "particular operator."
+                    )
+                return self.type.adapt_operator(op), self.type
+            else:
+                return op, self.type
 
-        See also UserDefinedType._adapt_expression(), an as-yet-
-        semi-public method with greater capability in this regard.
+    comparator_factory = Comparator
 
-        """
-        return op
 
 class TypeDecorator(TypeEngine):
     """Allows the creation of types which add additional functionality
@@ -837,13 +786,6 @@ class TypeDecorator(TypeEngine):
         """
         return self.impl.compare_values(x, y)
 
-    def _adapt_expression(self, op, othertype):
-        op, typ = self.impl._adapt_expression(op, othertype)
-        typ = to_instance(typ)
-        if typ._compare_type_affinity(self.impl):
-            return op, self
-        else:
-            return op, typ
 
 class Variant(TypeDecorator):
     """A wrapping type that selects among a variety of
@@ -926,8 +868,6 @@ def adapt_type(typeobj, colspecs):
     return typeobj.adapt(impltype)
 
 
-
-
 class NullType(TypeEngine):
     """An unknown type.
 
@@ -943,11 +883,14 @@ class NullType(TypeEngine):
     """
     __visit_name__ = 'null'
 
-    def _adapt_expression(self, op, othertype):
-        if isinstance(othertype, NullType) or not operators.is_commutative(op):
-            return op, self
-        else:
-            return othertype._adapt_expression(op, self)
+    class Comparator(TypeEngine.Comparator):
+        def _adapt_expression(self, op, other_comparator):
+            if isinstance(other_comparator, NullType.Comparator) or \
+                not operators.is_commutative(op):
+                return op, self.expr.type
+            else:
+                return other_comparator._adapt_expression(op, self)
+    comparator_factory = Comparator
 
 NullTypeEngine = NullType
 
@@ -955,12 +898,16 @@ class Concatenable(object):
     """A mixin that marks a type as supporting 'concatenation',
     typically strings."""
 
-    def _adapt_expression(self, op, othertype):
-        if op is operators.add and issubclass(othertype._type_affinity,
-                (Concatenable, NullType)):
-            return operators.concat_op, self
-        else:
-            return op, self
+    class Comparator(TypeEngine.Comparator):
+        def _adapt_expression(self, op, other_comparator):
+            if op is operators.add and isinstance(other_comparator,
+                    (Concatenable.Comparator, NullType.Comparator)):
+                return operators.concat_op, self.expr.type
+            else:
+                return op, self.expr.type
+
+    comparator_factory = Comparator
+
 
 class _DateAffinity(object):
     """Mixin date/time specific expression adaptations.
@@ -975,12 +922,14 @@ class _DateAffinity(object):
     def _expression_adaptations(self):
         raise NotImplementedError()
 
-    _blank_dict = util.immutabledict()
-    def _adapt_expression(self, op, othertype):
-        othertype = othertype._type_affinity
-        return op, \
-                self._expression_adaptations.get(op, self._blank_dict).\
-                get(othertype, NULLTYPE)
+    class Comparator(TypeEngine.Comparator):
+        _blank_dict = util.immutabledict()
+        def _adapt_expression(self, op, other_comparator):
+            othertype = other_comparator.type._type_affinity
+            return op, \
+                    self.type._expression_adaptations.get(op, self._blank_dict).\
+                    get(othertype, NULLTYPE)
+    comparator_factory = Comparator
 
 class String(Concatenable, TypeEngine):
     """The base for all string and character types.
index c38f95a015e67de4b8bb3fd8d675233651d319ff..05de8c9ef45282505b2598ad857ea579cbe62769 100644 (file)
@@ -12,18 +12,16 @@ from sqlalchemy.types import Integer, TypeEngine, TypeDecorator
 class DefaultColumnComparatorTest(fixtures.TestBase):
 
     def _do_scalar_test(self, operator, compare_to):
-        cc = _DefaultColumnComparator()
         left = column('left')
-        assert cc.operate(left, operator).compare(
+        assert left.comparator.operate(operator).compare(
             compare_to(left)
         )
 
     def _do_operate_test(self, operator):
-        cc = _DefaultColumnComparator()
         left = column('left')
         right = column('right')
 
-        assert cc.operate(left, operator, right, result_type=Integer).compare(
+        assert left.comparator.operate(operator, right).compare(
             BinaryExpression(left, right, operator)
         )
 
@@ -37,9 +35,8 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         self._do_operate_test(operators.add)
 
     def test_in(self):
-        cc = _DefaultColumnComparator()
         left = column('left')
-        assert cc.operate(left, operators.in_op, [1, 2, 3]).compare(
+        assert left.comparator.operate(operators.in_op, [1, 2, 3]).compare(
                 BinaryExpression(
                     left,
                     Grouping(ClauseList(
@@ -50,10 +47,9 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
             )
 
     def test_collate(self):
-        cc = _DefaultColumnComparator()
         left = column('left')
         right = "some collation"
-        cc.operate(left, operators.collate, right).compare(
+        left.comparator.operate(operators.collate, right).compare(
             collate(left, right)
         )
 
@@ -144,12 +140,8 @@ class _CustomComparatorTests(object):
         self._assert_add_override(6 - c1)
 
     def test_binary_multi_propagate(self):
-        c1 = Column('foo', self._add_override_factory(True))
-        self._assert_add_override((c1 - 6) + 5)
-
-    def test_no_binary_multi_propagate_wo_adapt(self):
         c1 = Column('foo', self._add_override_factory())
-        self._assert_not_add_override((c1 - 6) + 5)
+        self._assert_add_override((c1 - 6) + 5)
 
     def test_no_boolean_propagate(self):
         c1 = Column('foo', self._add_override_factory())
@@ -166,7 +158,7 @@ class _CustomComparatorTests(object):
         )
 
 class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
 
         class MyInteger(Integer):
             class comparator_factory(TypeEngine.Comparator):
@@ -176,19 +168,12 @@ class CustomComparatorTest(_CustomComparatorTests, fixtures.TestBase):
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
 
         return MyInteger
 
 
 class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
 
         class MyInteger(TypeDecorator):
             impl = Integer
@@ -200,19 +185,12 @@ class TypeDecoratorComparatorTest(_CustomComparatorTests, fixtures.TestBase):
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
 
         return MyInteger
 
 
 class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
         class MyInteger(Integer):
             class comparator_factory(TypeEngine.Comparator):
                 def __init__(self, expr):
@@ -221,13 +199,6 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas
                 def __add__(self, other):
                     return self.expr.op("goofy")(other)
 
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
 
         class MyDecInteger(TypeDecorator):
             impl = MyInteger
@@ -235,7 +206,7 @@ class CustomEmbeddedinTypeDecoratorTest(_CustomComparatorTests, fixtures.TestBas
         return MyDecInteger
 
 class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
-    def _add_override_factory(self, include_adapt=False):
+    def _add_override_factory(self):
         class MyInteger(Integer):
             class comparator_factory(TypeEngine.Comparator):
                 def __init__(self, expr):
@@ -243,15 +214,6 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
 
                 def foob(self, other):
                     return self.expr.op("foob")(other)
-
-            if include_adapt:
-                def _adapt_expression(self, op, othertype):
-                    if op.__name__ == 'custom_op':
-                        return op, self
-                    else:
-                        return super(MyInteger, self)._adapt_expression(
-                                                            op, othertype)
-
         return MyInteger
 
     def _assert_add_override(self, expr):
@@ -262,5 +224,3 @@ class NewOperatorTest(_CustomComparatorTests, fixtures.TestBase):
     def _assert_not_add_override(self, expr):
         assert not hasattr(expr, "foob")
 
-    def test_no_binary_multi_propagate_wo_adapt(self):
-        pass
\ No newline at end of file
index 91bf17175f0219fc656c8fce9ce84ce570dae0e8..279ae36a0a700570c3cb4f6c40fe3786da5aab7e 100644 (file)
@@ -1222,6 +1222,7 @@ class ExpressionTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
         eq_(expr.right.type.__class__, CHAR)
 
 
+    @testing.uses_deprecated
     @testing.fails_on('firebird', 'Data type unknown on the parameter')
     @testing.fails_on('mssql', 'int is unsigned ?  not clear')
     def test_operator_adapt(self):