]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve test_difference() cases and fix difference() accordingly
authorLele Gaifax <lele@metapensiero.it>
Wed, 9 Nov 2022 08:47:37 +0000 (09:47 +0100)
committerLele Gaifax <lele@metapensiero.it>
Wed, 9 Nov 2022 08:47:37 +0000 (09:47 +0100)
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index 7327f6f1cf1f5a8f2391ae41e9409bbca7edc040..d12b4dffe62a92d239bf69ba8d67b6287c6f6f54 100644 (file)
@@ -525,7 +525,12 @@ class Range(Generic[_T]):
         # middle
         if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0:
             rupper_b = ")" if olower_b == "[" else "]"
-            if self._compare_edges(slower, slower_b, olower, rupper_b) == 0:
+            if (
+                slower_b != "["
+                and rupper_b != "]"
+                and self._compare_edges(slower, slower_b, olower, rupper_b)
+                == 0
+            ):
                 return Range(None, None, empty=True)
             else:
                 return Range(slower, olower, bounds=slower_b + rupper_b)
@@ -533,8 +538,13 @@ class Range(Generic[_T]):
         # If this range starts in the middle of the other and extends to its
         # right
         if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0:
-            rlower_b = "(" if oupper_b == "]" else "("
-            if self._compare_edges(oupper, rlower_b, supper, supper_b) == 0:
+            rlower_b = "(" if oupper_b == "]" else "["
+            if (
+                rlower_b != "["
+                and supper_b != "]"
+                and self._compare_edges(oupper, rlower_b, supper, supper_b)
+                == 0
+            ):
                 return Range(None, None, empty=True)
             else:
                 return Range(oupper, supper, bounds=rlower_b + supper_b)
index 128041e4adabb0163be51547ea1421a7a7ef58a5..faa1af7b19c04d1922fc98e468119d998c003562 100644 (file)
@@ -4280,29 +4280,37 @@ class _RangeComparisonFixtures(_RangeTests):
             )
 
     @testing.combinations(
-        ("empty", "empty", False),
-        ("empty", "[{le},{re_}]", False),
-        ("[{le},{re_}]", "empty", False),
-        ("[{ll},{ih}]", "({le},{ih}]", False),
-        ("[{ll},{rh})", "[{le},{re_}]", False),
-        ("[{le},{re_}]", "({le},{re_})", True),
-        ("[{ll},{rh}]", "[{le},{re_}]", True),
-        argnames="r1repr,r2repr,err",
+        lambda r, e: Range(empty=True),
+        lambda r, e: r,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
+        lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"),
+        argnames="r1t",
     )
-    def test_difference(self, connection, r1repr, r2repr, err):
-        data = self._value_values()
-
-        if r1repr != "empty":
-            r1repr = r1repr.format(**data)
-        if r2repr != "empty":
-            r2repr = r2repr.format(**data)
+    @testing.combinations(
+        lambda r, e: Range(empty=True),
+        lambda r, e: r,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper - e, bounds="(]"),
+        lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
+        lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"),
+        lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper, bounds="()"),
+        argnames="r2t",
+    )
+    def test_difference(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            literal_column(f"'{r1}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2}'::{range_typ}", RANGE).label("r2"),
         )
 
         r1, r2 = connection.execute(q).first()
@@ -4311,17 +4319,17 @@ class _RangeComparisonFixtures(_RangeTests):
             literal_column(f"'{r1}'::{range_typ}-'{r2}'::{range_typ}", RANGE),
         )
 
-        if err:
-            with expect_raises(DataError):
-                connection.execute(resq).scalar()
+        try:
+            difference = connection.execute(resq).scalar()
+        except DataError:
             with expect_raises(ValueError):
                 r1.difference(r2)
         else:
-            difference = connection.execute(resq).scalar()
             eq_(
                 r1.difference(r2),
                 difference,
-                f"{r1}.difference({r2}) != {difference}",
+                f"{r1}.difference({r2}): got {r1.difference(r2)},"
+                f" expected {difference}",
             )
 
     @testing.combinations(