From: Mike Bayer Date: Thu, 27 Apr 2017 14:26:10 +0000 (-0400) Subject: Enforce boolean result type for all eq_, is_, isnot, comparison X-Git-Tag: rel_1_2_0b1~87 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=433d2ee9f14a028399e848f3552a1a71f223c976;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Enforce boolean result type for all eq_, is_, isnot, comparison Repaired issue where the type of an expression that used :meth:`.ColumnOperators.is_` or similar would not be a "boolean" type, instead the type would be "nulltype", as well as when using custom comparison operators against an untyped expression. This typing can impact how the expression behaves in larger contexts as well as in result-row-handling. Change-Id: Ib810ff686de500d8db26ae35a51005fab29603b6 Fixes: #3873 --- diff --git a/doc/build/changelog/changelog_12.rst b/doc/build/changelog/changelog_12.rst index b87682b6dc..2b67414948 100644 --- a/doc/build/changelog/changelog_12.rst +++ b/doc/build/changelog/changelog_12.rst @@ -13,6 +13,17 @@ .. changelog:: :version: 1.2.0b1 + .. change:: 3873 + :tags: bug, sql + :tickets: 3873 + + Repaired issue where the type of an expression that used + :meth:`.ColumnOperators.is_` or similar would not be a "boolean" type, + instead the type would be "nulltype", as well as when using custom + comparison operators against an untyped expression. This typing can + impact how the expression behaves in larger contexts as well as + in result-row-handling. + .. change:: 3969 :tags: bug, sql :tickets: 3969 diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 4ba53ef758..4485c661b9 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -50,11 +50,15 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False, if op in (operators.eq, operators.is_): return BinaryExpression(expr, _const_expr(obj), operators.is_, - negate=operators.isnot) + negate=operators.isnot, + type_=result_type + ) elif op in (operators.ne, operators.isnot): return BinaryExpression(expr, _const_expr(obj), operators.isnot, - negate=operators.is_) + negate=operators.is_, + type_=result_type + ) else: raise exc.ArgumentError( "Only '=', '!=', 'is_()', 'isnot()', " diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 49642acdd5..58f32b3e6f 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1021,7 +1021,8 @@ def json_path_getitem_op(a, b): _commutative = {eq, ne, add, mul} -_comparison = {eq, ne, lt, gt, ge, le, between_op, like_op} +_comparison = {eq, ne, lt, gt, ge, le, between_op, like_op, is_, + isnot, is_distinct_from, isnot_distinct_from} def is_comparison(op): diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b8117e3ca1..b8c8c81168 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2555,7 +2555,9 @@ class NullType(TypeEngine): class Comparator(TypeEngine.Comparator): def _adapt_expression(self, op, other_comparator): - if isinstance(other_comparator, NullType.Comparator) or \ + if operators.is_comparison(op): + return op, BOOLEANTYPE + elif isinstance(other_comparator, NullType.Comparator) or \ not operators.is_commutative(op): return op, self.expr.type else: diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index c0637d2257..3dd9af5e23 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -3,6 +3,7 @@ from sqlalchemy import testing from sqlalchemy.testing import assert_raises_message from sqlalchemy.sql import column, desc, asc, literal, collate, null, \ true, false, any_, all_ +from sqlalchemy.sql import sqltypes from sqlalchemy.sql.expression import BinaryExpression, \ ClauseList, Grouping, \ UnaryExpression, select, union, func, tuple_ @@ -62,6 +63,12 @@ class DefaultColumnComparatorTest(fixtures.TestBase): self._loop_test(operator, right) + if operators.is_comparison(operator): + is_( + left.comparator.operate(operator, right).type, + sqltypes.BOOLEANTYPE + ) + def _loop_test(self, operator, *arg): loop = LoopOperate() is_( @@ -2617,6 +2624,9 @@ class CustomOpTest(fixtures.TestBase): assert operators.is_comparison(op1) assert not operators.is_comparison(op2) + expr = c.op('$', is_comparison=True)(None) + is_(expr.type, sqltypes.BOOLEANTYPE) + class TupleTypingTest(fixtures.TestBase):