]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
-we move all the invocation of "_adapt_expression" into TypeEngine.Comparator. ...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2012 18:07:33 +0000 (14:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2012 18:07:33 +0000 (14:07 -0400)
the split of operator stuff is getting awkward and we might want to move _DefaultComparator.

lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/types.py
test/sql/test_operators.py

index a0715a975df69ff35cdeeaf150e706851e709276..844293c73673c7eca0040d2302cd9377f821fc15 100644 (file)
@@ -1883,7 +1883,7 @@ class _DefaultColumnComparator(object):
 
     """
 
-    def __compare(self, expr, op, obj, negate=None, reverse=False,
+    def _boolean_compare(self, expr, op, obj, negate=None, reverse=False,
                         **kwargs
         ):
         if obj is None or isinstance(obj, Null):
@@ -1912,7 +1912,7 @@ class _DefaultColumnComparator(object):
                             type_=sqltypes.BOOLEANTYPE,
                             negate=negate, modifiers=kwargs)
 
-    def __operate(self, expr, op, obj, reverse=False):
+    def _binary_operate(self, expr, op, obj, result_type, reverse=False):
         obj = self._check_literal(expr, op, obj)
 
         if reverse:
@@ -1920,25 +1920,16 @@ class _DefaultColumnComparator(object):
         else:
             left, right = expr, obj
 
-        if left.type is None:
-            op, result_type = sqltypes.NULLTYPE._adapt_expression(op,
-                    right.type)
-        elif right.type is None:
-            op, result_type = left.type._adapt_expression(op,
-                    sqltypes.NULLTYPE)
-        else:
-            op, result_type = left.type._adapt_expression(op, right.type)
-
         return BinaryExpression(left, right, op, type_=result_type)
 
-    def __scalar(self, expr, op, fn, **kw):
+    def _scalar(self, expr, op, fn, **kw):
         return fn(expr)
 
-    def _in_impl(self, expr, op, seq_or_selectable, negate_op):
+    def _in_impl(self, expr, op, seq_or_selectable, negate_op, **kw):
         seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
 
         if isinstance(seq_or_selectable, ScalarSelect):
-            return self.__compare(expr, op, seq_or_selectable,
+            return self._boolean_compare(expr, op, seq_or_selectable,
                                   negate=negate_op)
         elif isinstance(seq_or_selectable, SelectBase):
 
@@ -1947,11 +1938,11 @@ class _DefaultColumnComparator(object):
             # as_scalar() to produce a multi- column selectable that
             # does not export itself as a FROM clause
 
-            return self.__compare(expr, op, seq_or_selectable.as_scalar(),
-                                  negate=negate_op)
+            return self._boolean_compare(expr, op, seq_or_selectable.as_scalar(),
+                                  negate=negate_op, **kw)
         elif isinstance(seq_or_selectable, (Selectable, TextClause)):
-            return self.__compare(expr, op, seq_or_selectable,
-                                  negate=negate_op)
+            return self._boolean_compare(expr, op, seq_or_selectable,
+                                  negate=negate_op, **kw)
 
 
         # Handle non selectable arguments as sequences
@@ -1981,17 +1972,17 @@ class _DefaultColumnComparator(object):
                       'strategies for improved performance.' % expr)
             return expr != expr
 
-        return self.__compare(expr, op,
+        return self._boolean_compare(expr, op,
                               ClauseList(*args).self_group(against=op),
                               negate=negate_op)
     def _neg_impl(self, expr, op, **kw):
         """See :meth:`.ColumnOperators.__neg__`."""
         return UnaryExpression(expr, operator=operators.neg)
 
-    def _startswith_impl(self, expr, op, other, escape=None):
+    def _startswith_impl(self, expr, op, other, escape=None, **kw):
         """See :meth:`.ColumnOperators.startswith`."""
         # use __radd__ to force string concat behavior
-        return self.__compare(
+        return self._boolean_compare(
             expr,
             operators.like_op,
             literal_column("'%'", type_=sqltypes.String).__radd__(
@@ -1999,18 +1990,18 @@ class _DefaultColumnComparator(object):
                             ),
             escape=escape)
 
-    def _endswith_impl(self, expr, op, other, escape=None):
+    def _endswith_impl(self, expr, op, other, escape=None, **kw):
         """See :meth:`.ColumnOperators.endswith`."""
-        return self.__compare(
+        return self._boolean_compare(
             expr,
             operators.like_op,
             literal_column("'%'", type_=sqltypes.String) +
                 self._check_literal(expr, operators.like_op, other),
             escape=escape)
 
-    def _contains_impl(self, expr, op, other, escape=None):
+    def _contains_impl(self, expr, op, other, escape=None, **kw):
         """See :meth:`.ColumnOperators.contains`."""
-        return self.__compare(
+        return self._boolean_compare(
             expr,
             operators.like_op,
             literal_column("'%'", type_=sqltypes.String) +
@@ -2018,13 +2009,13 @@ class _DefaultColumnComparator(object):
                 literal_column("'%'", type_=sqltypes.String),
             escape=escape)
 
-    def _match_impl(self, expr, op, other):
+    def _match_impl(self, expr, op, other, **kw):
         """See :meth:`.ColumnOperators.match`."""
-        return self.__compare(expr, operators.match_op,
+        return self._boolean_compare(expr, operators.match_op,
                               self._check_literal(expr, operators.match_op,
                               other))
 
-    def _distinct_impl(self, expr, op):
+    def _distinct_impl(self, expr, op, **kw):
         """See :meth:`.ColumnOperators.distinct`."""
         return UnaryExpression(expr, operator=operators.distinct_op,
                                 type_=expr.type)
@@ -2040,32 +2031,32 @@ class _DefaultColumnComparator(object):
                     group=False),
                 operators.between_op)
 
-    def _collate_impl(self, expr, op, other):
+    def _collate_impl(self, expr, op, other, **kw):
         return collate(expr, other)
 
-    # a mapping of operators with the method they use, along with their negated
-    # operator for comparison operators
+    # a mapping of operators with the method they use, along with
+    # their negated operator for comparison operators
     operators = {
-        "add": (__operate,),
-        "mul": (__operate,),
-        "sub": (__operate,),
-        "div": (__operate,),
-        "mod": (__operate,),
-        "truediv": (__operate,),
-        "custom_op": (__operate,),
-        "concat_op": (__operate,),
-        "lt": (__compare, operators.ge),
-        "le": (__compare, operators.gt),
-        "ne": (__compare, operators.eq),
-        "gt": (__compare, operators.le),
-        "ge": (__compare, operators.lt),
-        "eq": (__compare, operators.ne),
-        "like_op": (__compare, operators.notlike_op),
-        "ilike_op": (__compare, operators.notilike_op),
-        "desc_op": (__scalar, desc),
-        "asc_op": (__scalar, asc),
-        "nullsfirst_op": (__scalar, nullsfirst),
-        "nullslast_op": (__scalar, nullslast),
+        "add": (_binary_operate,),
+        "mul": (_binary_operate,),
+        "sub": (_binary_operate,),
+        "div": (_binary_operate,),
+        "mod": (_binary_operate,),
+        "truediv": (_binary_operate,),
+        "custom_op": (_binary_operate,),
+        "concat_op": (_binary_operate,),
+        "lt": (_boolean_compare, operators.ge),
+        "le": (_boolean_compare, operators.gt),
+        "ne": (_boolean_compare, operators.eq),
+        "gt": (_boolean_compare, operators.le),
+        "ge": (_boolean_compare, operators.lt),
+        "eq": (_boolean_compare, operators.ne),
+        "like_op": (_boolean_compare, operators.notlike_op),
+        "ilike_op": (_boolean_compare, operators.notilike_op),
+        "desc_op": (_scalar, desc),
+        "asc_op": (_scalar, asc),
+        "nullsfirst_op": (_scalar, nullsfirst),
+        "nullslast_op": (_scalar, nullslast),
         "in_op": (_in_impl, operators.notin_op),
         "collate": (_collate_impl,),
         "match_op": (_match_impl,),
@@ -2085,7 +2076,6 @@ class _DefaultColumnComparator(object):
         o = self.operators[op.__name__]
         return o[0](self, expr, op, other, reverse=True, *o[1:], **kwargs)
 
-
     def _check_literal(self, expr, operator, other):
         if isinstance(other, BindParameter) and \
             isinstance(other.type, sqltypes.NullType):
@@ -2096,15 +2086,13 @@ class _DefaultColumnComparator(object):
             return other
         elif hasattr(other, '__clause_element__'):
             other = other.__clause_element__()
-            if isinstance(other, (SelectBase, Alias)):
-                other = other.as_scalar()
-            return other
         elif isinstance(other, sqltypes.TypeEngine.Comparator):
-            return other.expr
-        elif not isinstance(other, ClauseElement):
-            return expr._bind_param(operator, other)
-        elif isinstance(other, (SelectBase, Alias)):
+            other = other.expr
+
+        if isinstance(other, (SelectBase, Alias)):
             return other.as_scalar()
+        elif not isinstance(other, (ColumnElement, TextClause)):
+            return expr._bind_param(operator, other)
         else:
             return other
 
index 40cf2f3314fc44930983b2f2d01bacd942aa7cce..d4dbd648c2b6641bf4232ae64884cff9be4cff05 100644 (file)
@@ -55,10 +55,22 @@ class TypeEngine(AbstractType):
             return _reconstitute_comparator, (self.expr, )
 
         def operate(self, op, *other, **kwargs):
+            if len(other) == 1:
+                obj = other[0]
+                obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, obj)
+                op, adapt_type = self.expr.type._adapt_expression(op,
+                    obj.type)
+                kwargs['result_type'] = adapt_type
+
             return _DEFAULT_COMPARATOR.operate(self.expr, op, *other, **kwargs)
 
         def reverse_operate(self, op, other, **kwargs):
-            return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, other,
+
+            obj = _DEFAULT_COMPARATOR._check_literal(self.expr, op, other)
+            op, adapt_type = obj.type._adapt_expression(op, self.expr.type)
+            kwargs['result_type'] = adapt_type
+
+            return _DEFAULT_COMPARATOR.reverse_operate(self.expr, op, obj,
                                                 **kwargs)
 
     comparator_factory = Comparator
index b369dfc1144692d67aa221cc9fe50553ce16bb6c..c38f95a015e67de4b8bb3fd8d675233651d319ff 100644 (file)
@@ -23,7 +23,7 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
         left = column('left')
         right = column('right')
 
-        assert cc.operate(left, operator, right).compare(
+        assert cc.operate(left, operator, right, result_type=Integer).compare(
             BinaryExpression(left, right, operator)
         )