From: Jim Bosch Date: Sun, 26 Jul 2020 20:50:14 +0000 (-0400) Subject: Ensure is_comparison passed for PG RANGE op() methods X-Git-Tag: rel_1_4_0b1~211^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=07e57a0330fb7b1bbe0c59f442111a34e4b7c960;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Ensure is_comparison passed for PG RANGE op() methods Fixed issue where the return type for the various RANGE comparison operators would itself be the same RANGE type rather than BOOLEAN, which would cause an undesirable result in the case that a :class:`.TypeDecorator` that defined result-processing behavior were in use. Pull request courtesy Jim Bosch. Fixes: #5476 Closes: #5477 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5477 Pull-request-sha: 925b117e0c91cdd67d9ddbd9d65f5ca3e88af91f Change-Id: I52ab4d4362d379c8253990f9d328a40990a64520 --- diff --git a/doc/build/changelog/unreleased_13/5476.rst b/doc/build/changelog/unreleased_13/5476.rst new file mode 100644 index 0000000000..abbf9b7bde --- /dev/null +++ b/doc/build/changelog/unreleased_13/5476.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, postgresql + :tickets: 5476 + + Fixed issue where the return type for the various RANGE comparison + operators would itself be the same RANGE type rather than BOOLEAN, which + would cause an undesirable result in the case that a + :class:`.TypeDecorator` that defined result-processing behavior were in + use. Pull request courtesy Jim Bosch. + + diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index d4f75b4948..a31d958ed9 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -36,32 +36,32 @@ class RangeOperators(object): other ) else: - return self.expr.op("<>")(other) + return self.expr.op("<>", is_comparison=True)(other) def contains(self, other, **kw): """Boolean expression. Returns true if the right hand operand, which can be an element or a range, is contained within the column. """ - return self.expr.op("@>")(other) + return self.expr.op("@>", is_comparison=True)(other) def contained_by(self, other): """Boolean expression. Returns true if the column is contained within the right hand operand. """ - return self.expr.op("<@")(other) + return self.expr.op("<@", is_comparison=True)(other) def overlaps(self, other): """Boolean expression. Returns true if the column overlaps (has points in common with) the right hand operand. """ - return self.expr.op("&&")(other) + return self.expr.op("&&", is_comparison=True)(other) def strictly_left_of(self, other): """Boolean expression. Returns true if the column is strictly left of the right hand operand. """ - return self.expr.op("<<")(other) + return self.expr.op("<<", is_comparison=True)(other) __lshift__ = strictly_left_of @@ -69,7 +69,7 @@ class RangeOperators(object): """Boolean expression. Returns true if the column is strictly right of the right hand operand. """ - return self.expr.op(">>")(other) + return self.expr.op(">>", is_comparison=True)(other) __rshift__ = strictly_right_of @@ -77,19 +77,19 @@ class RangeOperators(object): """Boolean expression. Returns true if the range in the column does not extend right of the range in the operand. """ - return self.expr.op("&<")(other) + return self.expr.op("&<", is_comparison=True)(other) def not_extend_left_of(self, other): """Boolean expression. Returns true if the range in the column does not extend left of the range in the operand. """ - return self.expr.op("&>")(other) + return self.expr.op("&>", is_comparison=True)(other) def adjacent_to(self, other): """Boolean expression. Returns true if the range in the column is adjacent to the range in the operand. """ - return self.expr.op("-|-")(other) + return self.expr.op("-|-", is_comparison=True)(other) def __add__(self, other): """Range expression. Returns the union of the two ranges. diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 9331f99105..d229892918 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -2628,112 +2628,149 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): ) cls.col = table.c.range - def _test_clause(self, colclause, expected): + def _test_clause(self, colclause, expected, type_): self.assert_compile(colclause, expected) + is_(colclause.type._type_affinity, type_._type_affinity) def test_where_equal(self): self._test_clause( - self.col == self._data_str, "data_table.range = %(range_1)s" + self.col == self._data_str, + "data_table.range = %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_where_not_equal(self): self._test_clause( - self.col != self._data_str, "data_table.range <> %(range_1)s" + self.col != self._data_str, + "data_table.range <> %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_where_is_null(self): - self._test_clause(self.col == None, "data_table.range IS NULL") + self._test_clause( + self.col == None, "data_table.range IS NULL", sqltypes.BOOLEANTYPE + ) def test_where_is_not_null(self): - self._test_clause(self.col != None, "data_table.range IS NOT NULL") + self._test_clause( + self.col != None, + "data_table.range IS NOT NULL", + sqltypes.BOOLEANTYPE, + ) def test_where_less_than(self): self._test_clause( - self.col < self._data_str, "data_table.range < %(range_1)s" + self.col < self._data_str, + "data_table.range < %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_where_greater_than(self): self._test_clause( - self.col > self._data_str, "data_table.range > %(range_1)s" + self.col > self._data_str, + "data_table.range > %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_where_less_than_or_equal(self): self._test_clause( - self.col <= self._data_str, "data_table.range <= %(range_1)s" + self.col <= self._data_str, + "data_table.range <= %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_where_greater_than_or_equal(self): self._test_clause( - self.col >= self._data_str, "data_table.range >= %(range_1)s" + self.col >= self._data_str, + "data_table.range >= %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_contains(self): self._test_clause( self.col.contains(self._data_str), "data_table.range @> %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_contained_by(self): self._test_clause( self.col.contained_by(self._data_str), "data_table.range <@ %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_overlaps(self): self._test_clause( self.col.overlaps(self._data_str), "data_table.range && %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_strictly_left_of(self): self._test_clause( - self.col << self._data_str, "data_table.range << %(range_1)s" + self.col << self._data_str, + "data_table.range << %(range_1)s", + sqltypes.BOOLEANTYPE, ) self._test_clause( self.col.strictly_left_of(self._data_str), "data_table.range << %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_strictly_right_of(self): self._test_clause( - self.col >> self._data_str, "data_table.range >> %(range_1)s" + self.col >> self._data_str, + "data_table.range >> %(range_1)s", + sqltypes.BOOLEANTYPE, ) self._test_clause( self.col.strictly_right_of(self._data_str), "data_table.range >> %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_not_extend_right_of(self): self._test_clause( self.col.not_extend_right_of(self._data_str), "data_table.range &< %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_not_extend_left_of(self): self._test_clause( self.col.not_extend_left_of(self._data_str), "data_table.range &> %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_adjacent_to(self): self._test_clause( self.col.adjacent_to(self._data_str), "data_table.range -|- %(range_1)s", + sqltypes.BOOLEANTYPE, ) def test_union(self): self._test_clause( - self.col + self.col, "data_table.range + data_table.range" + self.col + self.col, + "data_table.range + data_table.range", + self.col.type, ) def test_intersection(self): self._test_clause( - self.col * self.col, "data_table.range * data_table.range" + self.col * self.col, + "data_table.range * data_table.range", + self.col.type, ) def test_different(self): self._test_clause( - self.col - self.col, "data_table.range - data_table.range" + self.col - self.col, + "data_table.range - data_table.range", + self.col.type, )