]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure is_comparison passed for PG RANGE op() methods
authorJim Bosch <jbosch@astro.princeton.edu>
Sun, 26 Jul 2020 20:50:14 +0000 (16:50 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 26 Jul 2020 20:57:13 +0000 (16:57 -0400)
Fixed issue where the return type for the various RANGE comparison
operators would itself be the same RANGE type rather than BOOLEAN, which
would cause an undesirable result in the case that a
:class:`.TypeDecorator` that defined result-processing behavior were in
use.  Pull request courtesy Jim Bosch.

Fixes: #5476
Closes: #5477
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5477
Pull-request-sha: 925b117e0c91cdd67d9ddbd9d65f5ca3e88af91f

Change-Id: I52ab4d4362d379c8253990f9d328a40990a64520

doc/build/changelog/unreleased_13/5476.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/5476.rst b/doc/build/changelog/unreleased_13/5476.rst
new file mode 100644 (file)
index 0000000..abbf9b7
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 5476
+
+    Fixed issue where the return type for the various RANGE comparison
+    operators would itself be the same RANGE type rather than BOOLEAN, which
+    would cause an undesirable result in the case that a
+    :class:`.TypeDecorator` that defined result-processing behavior were in
+    use.  Pull request courtesy Jim Bosch.
+
+
index d4f75b4948cb70a3b3f7d8e19f28c4cbc8a5d93a..a31d958ed93b9dc7baa82947ef3f3ac25a929bf9 100644 (file)
@@ -36,32 +36,32 @@ class RangeOperators(object):
                     other
                 )
             else:
-                return self.expr.op("<>")(other)
+                return self.expr.op("<>", is_comparison=True)(other)
 
         def contains(self, other, **kw):
             """Boolean expression. Returns true if the right hand operand,
             which can be an element or a range, is contained within the
             column.
             """
-            return self.expr.op("@>")(other)
+            return self.expr.op("@>", is_comparison=True)(other)
 
         def contained_by(self, other):
             """Boolean expression. Returns true if the column is contained
             within the right hand operand.
             """
-            return self.expr.op("<@")(other)
+            return self.expr.op("<@", is_comparison=True)(other)
 
         def overlaps(self, other):
             """Boolean expression. Returns true if the column overlaps
             (has points in common with) the right hand operand.
             """
-            return self.expr.op("&&")(other)
+            return self.expr.op("&&", is_comparison=True)(other)
 
         def strictly_left_of(self, other):
             """Boolean expression. Returns true if the column is strictly
             left of the right hand operand.
             """
-            return self.expr.op("<<")(other)
+            return self.expr.op("<<", is_comparison=True)(other)
 
         __lshift__ = strictly_left_of
 
@@ -69,7 +69,7 @@ class RangeOperators(object):
             """Boolean expression. Returns true if the column is strictly
             right of the right hand operand.
             """
-            return self.expr.op(">>")(other)
+            return self.expr.op(">>", is_comparison=True)(other)
 
         __rshift__ = strictly_right_of
 
@@ -77,19 +77,19 @@ class RangeOperators(object):
             """Boolean expression. Returns true if the range in the column
             does not extend right of the range in the operand.
             """
-            return self.expr.op("&<")(other)
+            return self.expr.op("&<", is_comparison=True)(other)
 
         def not_extend_left_of(self, other):
             """Boolean expression. Returns true if the range in the column
             does not extend left of the range in the operand.
             """
-            return self.expr.op("&>")(other)
+            return self.expr.op("&>", is_comparison=True)(other)
 
         def adjacent_to(self, other):
             """Boolean expression. Returns true if the range in the column
             is adjacent to the range in the operand.
             """
-            return self.expr.op("-|-")(other)
+            return self.expr.op("-|-", is_comparison=True)(other)
 
         def __add__(self, other):
             """Range expression. Returns the union of the two ranges.
index 9331f991059d8af6c401b64ae06ed55679949a76..d229892918dec90513bd68028dfb986d8b44711b 100644 (file)
@@ -2628,112 +2628,149 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
         )
         cls.col = table.c.range
 
-    def _test_clause(self, colclause, expected):
+    def _test_clause(self, colclause, expected, type_):
         self.assert_compile(colclause, expected)
+        is_(colclause.type._type_affinity, type_._type_affinity)
 
     def test_where_equal(self):
         self._test_clause(
-            self.col == self._data_str, "data_table.range = %(range_1)s"
+            self.col == self._data_str,
+            "data_table.range = %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_where_not_equal(self):
         self._test_clause(
-            self.col != self._data_str, "data_table.range <> %(range_1)s"
+            self.col != self._data_str,
+            "data_table.range <> %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_where_is_null(self):
-        self._test_clause(self.col == None, "data_table.range IS NULL")
+        self._test_clause(
+            self.col == None, "data_table.range IS NULL", sqltypes.BOOLEANTYPE
+        )
 
     def test_where_is_not_null(self):
-        self._test_clause(self.col != None, "data_table.range IS NOT NULL")
+        self._test_clause(
+            self.col != None,
+            "data_table.range IS NOT NULL",
+            sqltypes.BOOLEANTYPE,
+        )
 
     def test_where_less_than(self):
         self._test_clause(
-            self.col < self._data_str, "data_table.range < %(range_1)s"
+            self.col < self._data_str,
+            "data_table.range < %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_where_greater_than(self):
         self._test_clause(
-            self.col > self._data_str, "data_table.range > %(range_1)s"
+            self.col > self._data_str,
+            "data_table.range > %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_where_less_than_or_equal(self):
         self._test_clause(
-            self.col <= self._data_str, "data_table.range <= %(range_1)s"
+            self.col <= self._data_str,
+            "data_table.range <= %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_where_greater_than_or_equal(self):
         self._test_clause(
-            self.col >= self._data_str, "data_table.range >= %(range_1)s"
+            self.col >= self._data_str,
+            "data_table.range >= %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_contains(self):
         self._test_clause(
             self.col.contains(self._data_str),
             "data_table.range @> %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_contained_by(self):
         self._test_clause(
             self.col.contained_by(self._data_str),
             "data_table.range <@ %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_overlaps(self):
         self._test_clause(
             self.col.overlaps(self._data_str),
             "data_table.range && %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_strictly_left_of(self):
         self._test_clause(
-            self.col << self._data_str, "data_table.range << %(range_1)s"
+            self.col << self._data_str,
+            "data_table.range << %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
         self._test_clause(
             self.col.strictly_left_of(self._data_str),
             "data_table.range << %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_strictly_right_of(self):
         self._test_clause(
-            self.col >> self._data_str, "data_table.range >> %(range_1)s"
+            self.col >> self._data_str,
+            "data_table.range >> %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
         self._test_clause(
             self.col.strictly_right_of(self._data_str),
             "data_table.range >> %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_not_extend_right_of(self):
         self._test_clause(
             self.col.not_extend_right_of(self._data_str),
             "data_table.range &< %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_not_extend_left_of(self):
         self._test_clause(
             self.col.not_extend_left_of(self._data_str),
             "data_table.range &> %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_adjacent_to(self):
         self._test_clause(
             self.col.adjacent_to(self._data_str),
             "data_table.range -|- %(range_1)s",
+            sqltypes.BOOLEANTYPE,
         )
 
     def test_union(self):
         self._test_clause(
-            self.col + self.col, "data_table.range + data_table.range"
+            self.col + self.col,
+            "data_table.range + data_table.range",
+            self.col.type,
         )
 
     def test_intersection(self):
         self._test_clause(
-            self.col * self.col, "data_table.range * data_table.range"
+            self.col * self.col,
+            "data_table.range * data_table.range",
+            self.col.type,
         )
 
     def test_different(self):
         self._test_clause(
-            self.col - self.col, "data_table.range - data_table.range"
+            self.col - self.col,
+            "data_table.range - data_table.range",
+            self.col.type,
         )