]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Enforce boolean result type for all eq_, is_, isnot, comparison
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Apr 2017 14:26:10 +0000 (10:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 27 Apr 2017 20:08:02 +0000 (16:08 -0400)
Repaired issue where the type of an expression that used
:meth:`.ColumnOperators.is_` or similar would not be a "boolean" type,
instead the type would be "nulltype", as well as when using custom
comparison operators against an untyped expression.   This typing can
impact how the expression behaves in larger contexts as well as
in result-row-handling.

Change-Id: Ib810ff686de500d8db26ae35a51005fab29603b6
Fixes: #3873
doc/build/changelog/changelog_12.rst
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_operators.py

index b87682b6dca0f14d6b3552622cf24ac3f213e710..2b67414948941bced537e51efa9c438a2c4bbdb8 100644 (file)
 .. changelog::
     :version: 1.2.0b1
 
+    .. change:: 3873
+        :tags: bug, sql
+        :tickets: 3873
+
+        Repaired issue where the type of an expression that used
+        :meth:`.ColumnOperators.is_` or similar would not be a "boolean" type,
+        instead the type would be "nulltype", as well as when using custom
+        comparison operators against an untyped expression.   This typing can
+        impact how the expression behaves in larger contexts as well as
+        in result-row-handling.
+
     .. change:: 3969
         :tags: bug, sql
         :tickets: 3969
index 4ba53ef758086cbd3029af833145f8e4ca867f1c..4485c661b9ddf02c2ec1b322441654cfeceda5f7 100644 (file)
@@ -50,11 +50,15 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False,
             if op in (operators.eq, operators.is_):
                 return BinaryExpression(expr, _const_expr(obj),
                                         operators.is_,
-                                        negate=operators.isnot)
+                                        negate=operators.isnot,
+                                        type_=result_type
+                                        )
             elif op in (operators.ne, operators.isnot):
                 return BinaryExpression(expr, _const_expr(obj),
                                         operators.isnot,
-                                        negate=operators.is_)
+                                        negate=operators.is_,
+                                        type_=result_type
+                                        )
             else:
                 raise exc.ArgumentError(
                     "Only '=', '!=', 'is_()', 'isnot()', "
index 49642acdd5b6499c2706afd3e3fbb3d071e29c99..58f32b3e6f50d7dc4910ec15636d7b42f5cbb421 100644 (file)
@@ -1021,7 +1021,8 @@ def json_path_getitem_op(a, b):
 
 _commutative = {eq, ne, add, mul}
 
-_comparison = {eq, ne, lt, gt, ge, le, between_op, like_op}
+_comparison = {eq, ne, lt, gt, ge, le, between_op, like_op, is_,
+               isnot, is_distinct_from, isnot_distinct_from}
 
 
 def is_comparison(op):
index b8117e3ca1ef93fdc85aa6ba730f25e44eae05d7..b8c8c811687b744ff06e87c08f5c639f56c59602 100644 (file)
@@ -2555,7 +2555,9 @@ class NullType(TypeEngine):
     class Comparator(TypeEngine.Comparator):
 
         def _adapt_expression(self, op, other_comparator):
-            if isinstance(other_comparator, NullType.Comparator) or \
+            if operators.is_comparison(op):
+                return op, BOOLEANTYPE
+            elif isinstance(other_comparator, NullType.Comparator) or \
                     not operators.is_commutative(op):
                 return op, self.expr.type
             else:
index c0637d22571343c1ad49de6c27d1a44c03aeffb8..3dd9af5e23f39a6a54330d1f761d18ce370459df 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import testing
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.sql import column, desc, asc, literal, collate, null, \
     true, false, any_, all_
+from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql.expression import BinaryExpression, \
     ClauseList, Grouping, \
     UnaryExpression, select, union, func, tuple_
@@ -62,6 +63,12 @@ class DefaultColumnComparatorTest(fixtures.TestBase):
 
         self._loop_test(operator, right)
 
+        if operators.is_comparison(operator):
+            is_(
+                left.comparator.operate(operator, right).type,
+                sqltypes.BOOLEANTYPE
+            )
+
     def _loop_test(self, operator, *arg):
         loop = LoopOperate()
         is_(
@@ -2617,6 +2624,9 @@ class CustomOpTest(fixtures.TestBase):
         assert operators.is_comparison(op1)
         assert not operators.is_comparison(op2)
 
+        expr = c.op('$', is_comparison=True)(None)
+        is_(expr.type, sqltypes.BOOLEANTYPE)
+
 
 class TupleTypingTest(fixtures.TestBase):