]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Further extend test combinations
authorLele Gaifax <lele@metapensiero.it>
Thu, 10 Nov 2022 07:22:28 +0000 (08:22 +0100)
committerLele Gaifax <lele@metapensiero.it>
Thu, 10 Nov 2022 07:22:28 +0000 (08:22 +0100)
test/dialect/postgresql/test_types.py

index faa1af7b19c04d1922fc98e468119d998c003562..55afd2d18da3247d59d05b607cd7f26651c56627 100644 (file)
@@ -4021,277 +4021,319 @@ class _RangeComparisonFixtures(_RangeTests):
         r, expected = connection.execute(q).first()
         eq_(r.contains(v), expected)
 
-    def test_contains_range(
-        self,
-        connection,
-        bounds_obj_combinations,
-        contains_range_obj_combinations,
-    ):
-        r1repr = contains_range_obj_combinations._stringify()
-        r2repr = bounds_obj_combinations._stringify()
+    _common_ranges_to_test = (
+        lambda r, e: Range(empty=True),
+        lambda r, e: Range(None, None, bounds="()"),
+        lambda r, e: Range(r.lower, None, bounds="[)"),
+        lambda r, e: Range(None, r.upper, bounds="(]"),
+        lambda r, e: r,
+        lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper, bounds="(]"),
+        lambda r, e: Range(r.lower, r.upper, bounds="()"),
+    )
+
+    @testing.combinations(
+        *_common_ranges_to_test,
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower + e, r.upper + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.upper - e, bounds="(]"),
+        lambda r, e: Range(r.lower + e, r.upper - e, bounds="[]"),
+        lambda r, e: Range(r.lower + e, r.upper - e, bounds="(]"),
+        lambda r, e: Range(r.lower + e, r.upper, bounds="(]"),
+        lambda r, e: Range(r.lower + e, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower + e, r.upper + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
+        lambda r, e: Range(r.lower - 2 * e, r.lower - e, bounds="(]"),
+        lambda r, e: Range(r.lower - 4 * e, r.lower, bounds="[)"),
+        lambda r, e: Range(r.upper + 4 * e, r.upper + 6 * e, bounds="()"),
+        argnames="r2t",
+    )
+    def test_contains_range(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            cast(contains_range_obj_combinations, RANGE).label("r1"),
-            cast(bounds_obj_combinations, RANGE).label("r2"),
-            cast(contains_range_obj_combinations, RANGE).contains(
-                bounds_obj_combinations
-            ),
-            cast(contains_range_obj_combinations, RANGE).contained_by(
-                bounds_obj_combinations
-            ),
+            cast(r1, RANGE).contains(r2),
+            cast(r1, RANGE).contained_by(r2),
         )
+
         validate_q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
-            literal_column(
-                f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}"
-            ),
-            literal_column(
-                f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}"
-            ),
+            literal_column(f"'{r1}'::{range_typ} @> '{r2}'::{range_typ}"),
+            literal_column(f"'{r1}'::{range_typ} <@ '{r2}'::{range_typ}"),
         )
-        orig_row = connection.execute(q).first()
+
+        row = connection.execute(q).first()
         validate_row = connection.execute(validate_q).first()
-        eq_(orig_row, validate_row)
+        eq_(row, validate_row)
 
-        r1, r2, contains, contained = orig_row
-        eq_(r1.contains(r2), contains, f"{r1}.contains({r2}) != {contains}")
+        pg_contains, pg_contained = row
+        py_contains = r1.contains(r2)
+        eq_(
+            py_contains,
+            pg_contains,
+            f"{r1}.contains({r2}): got {py_contains},"
+            f" expected {pg_contains}",
+        )
+        py_contained = r1.contained_by(r2)
         eq_(
-            r1.contained_by(r2),
-            contained,
-            f"{r1}.contained_by({r2}) != {contained}",
+            py_contained,
+            pg_contained,
+            f"{r1}.contained_by({r2}): got {py_contained},"
+            f" expected {pg_contained}",
+        )
+        eq_(
+            r2.contains(r1),
+            pg_contained,
+            f"{r2}.contains({r1}: got {r2.contains(r1)},"
+            f" expected {pg_contained})",
         )
-        eq_(r2.contains(r1), contained, f"{r2}.contains({r1} != {contained})")
 
-    def test_overlaps(
-        self,
-        connection,
-        bounds_obj_combinations,
-        contains_range_obj_combinations,
-    ):
-        r1repr = contains_range_obj_combinations._stringify()
-        r2repr = bounds_obj_combinations._stringify()
+    @testing.combinations(
+        *_common_ranges_to_test,
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower - 2 * e, r.lower - e, bounds="(]"),
+        lambda r, e: Range(r.upper + e, r.upper + 2 * e, bounds="[)"),
+        argnames="r2t",
+    )
+    def test_overlaps(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            cast(contains_range_obj_combinations, RANGE).label("r1"),
-            cast(bounds_obj_combinations, RANGE).label("r2"),
-            cast(contains_range_obj_combinations, RANGE).overlaps(
-                bounds_obj_combinations
-            ),
+            cast(r1, RANGE).overlaps(r2),
         )
         validate_q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
-            literal_column(
-                f"'{r1repr}'::{range_typ} && '{r2repr}'::{range_typ}"
-            ),
+            literal_column(f"'{r1}'::{range_typ} && '{r2}'::{range_typ}"),
         )
-        orig_row = connection.execute(q).first()
+        row = connection.execute(q).first()
         validate_row = connection.execute(validate_q).first()
-        eq_(orig_row, validate_row)
+        eq_(row, validate_row)
 
-        r1, r2, overlaps = orig_row
-        eq_(r1.overlaps(r2), overlaps, f"{r1}.overlaps({r2}) != {overlaps}")
+        pg_res = row[0]
+        py_res = r1.overlaps(r2)
+        eq_(
+            py_res,
+            pg_res,
+            f"{r1}.overlaps({r2}): got {py_res}, expected {pg_res}",
+        )
 
-    def test_strictly_left_or_right_of(
-        self,
-        connection,
-        bounds_obj_combinations,
-        contains_range_obj_combinations,
-    ):
-        r1repr = contains_range_obj_combinations._stringify()
-        r2repr = bounds_obj_combinations._stringify()
+    @testing.combinations(
+        *_common_ranges_to_test,
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="[]"),
+        lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="(]"),
+        lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[)"),
+        argnames="r2t",
+    )
+    def test_strictly_left_or_right_of(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            cast(contains_range_obj_combinations, RANGE).label("r1"),
-            cast(bounds_obj_combinations, RANGE).label("r2"),
-            cast(contains_range_obj_combinations, RANGE).strictly_left_of(
-                bounds_obj_combinations
-            ),
-            cast(contains_range_obj_combinations, RANGE).strictly_right_of(
-                bounds_obj_combinations
-            ),
+            cast(r1, RANGE).strictly_left_of(r2),
+            cast(r1, RANGE).strictly_right_of(r2),
         )
         validate_q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
-            literal_column(
-                f"'{r1repr}'::{range_typ} << '{r2repr}'::{range_typ}"
-            ),
-            literal_column(
-                f"'{r1repr}'::{range_typ} >> '{r2repr}'::{range_typ}"
-            ),
+            literal_column(f"'{r1}'::{range_typ} << '{r2}'::{range_typ}"),
+            literal_column(f"'{r1}'::{range_typ} >> '{r2}'::{range_typ}"),
         )
-        orig_row = connection.execute(q).first()
+
+        row = connection.execute(q).first()
         validate_row = connection.execute(validate_q).first()
-        eq_(orig_row, validate_row)
+        eq_(row, validate_row)
 
-        r1, r2, leftof, rightof = orig_row
+        pg_left, pg_right = row
+        py_left = r1.strictly_left_of(r2)
+        eq_(
+            py_left,
+            pg_left,
+            f"{r1}.strictly_left_of({r2}): got {py_left}, expected {pg_left}",
+        )
+        py_left = r1 << r2
         eq_(
-            r1.strictly_left_of(r2),
-            leftof,
-            f"{r1}.strictly_left_of({r2}) != {leftof}",
+            py_left,
+            pg_left,
+            f"{r1} << {r2}: got {py_left}, expected {pg_left}",
         )
-        eq_(r1 << r2, leftof, f"{r1} << {r2} != {leftof}")
+        py_right = r1.strictly_right_of(r2)
         eq_(
-            r1.strictly_right_of(r2),
-            rightof,
-            f"{r1}.strictly_right_of({r2}) != {rightof}",
+            py_right,
+            pg_right,
+            f"{r1}.strictly_right_of({r2}): got {py_left},"
+            f" expected {pg_right}",
+        )
+        py_right = r1 >> r2
+        eq_(
+            py_right,
+            pg_right,
+            f"{r1} >> {r2}: got {py_left}, expected {pg_right}",
         )
-        eq_(r1 >> r2, rightof, f"{r1} >> {r2} != {rightof}")
 
-    def test_not_extend_left_or_right_of(
-        self,
-        connection,
-        bounds_obj_combinations,
-        contains_range_obj_combinations,
-    ):
-        r1repr = contains_range_obj_combinations._stringify()
-        r2repr = bounds_obj_combinations._stringify()
+    @testing.combinations(
+        *_common_ranges_to_test,
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="[]"),
+        lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="(]"),
+        lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[)"),
+        argnames="r2t",
+    )
+    def test_not_extend_left_or_right_of(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            cast(contains_range_obj_combinations, RANGE).label("r1"),
-            cast(bounds_obj_combinations, RANGE).label("r2"),
-            cast(contains_range_obj_combinations, RANGE).not_extend_left_of(
-                bounds_obj_combinations
-            ),
-            cast(contains_range_obj_combinations, RANGE).not_extend_right_of(
-                bounds_obj_combinations
-            ),
+            cast(r1, RANGE).not_extend_left_of(r2),
+            cast(r1, RANGE).not_extend_right_of(r2),
         )
         validate_q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
-            literal_column(
-                f"'{r1repr}'::{range_typ} &> '{r2repr}'::{range_typ}"
-            ),
-            literal_column(
-                f"'{r1repr}'::{range_typ} &< '{r2repr}'::{range_typ}"
-            ),
+            literal_column(f"'{r1}'::{range_typ} &> '{r2}'::{range_typ}"),
+            literal_column(f"'{r1}'::{range_typ} &< '{r2}'::{range_typ}"),
         )
-        orig_row = connection.execute(q).first()
+        row = connection.execute(q).first()
         validate_row = connection.execute(validate_q).first()
-        eq_(orig_row, validate_row)
+        eq_(row, validate_row)
 
-        r1, r2, leftof, rightof = orig_row
+        pg_left, pg_right = row
+        py_left = r1.not_extend_left_of(r2)
         eq_(
-            r1.not_extend_left_of(r2),
-            leftof,
-            f"{r1}.not_extend_left_of({r2}) != {leftof}",
+            py_left,
+            pg_left,
+            f"{r1}.not_extend_left_of({r2}): got {py_left},"
+            f" expected {pg_left}",
         )
+        py_right = r1.not_extend_right_of(r2)
         eq_(
-            r1.not_extend_right_of(r2),
-            rightof,
-            f"{r1}.not_extend_right_of({r2}) != {rightof}",
+            py_right,
+            pg_right,
+            f"{r1}.not_extend_right_of({r2}): got {py_right},"
+            f" expected {pg_right}",
         )
 
-    def test_adjacent(
-        self,
-        connection,
-        bounds_obj_combinations,
-        contains_range_obj_combinations,
-    ):
-        r1repr = contains_range_obj_combinations._stringify()
-        r2repr = bounds_obj_combinations._stringify()
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower - e, r.lower + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower - e, bounds="[]"),
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower - e, r.lower + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower - e, bounds="[]"),
+        lambda r, e: Range(r.lower + e, r.upper - e, bounds="(]"),
+        lambda r, e: Range(r.lower + e, r.upper - e, bounds="[]"),
+        lambda r, e: Range(r.lower + e, r.upper, bounds="(]"),
+        lambda r, e: Range(r.lower + e, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower + e, r.upper + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower - e, bounds="[]"),
+        lambda r, e: Range(r.lower - 2 * e, r.lower - e, bounds="(]"),
+        lambda r, e: Range(r.lower - 4 * e, r.lower, bounds="[)"),
+        lambda r, e: Range(r.upper + 4 * e, r.upper + 6 * e, bounds="()"),
+        argnames="r2t",
+    )
+    def test_adjacent(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            cast(contains_range_obj_combinations, RANGE).label("r1"),
-            cast(bounds_obj_combinations, RANGE).label("r2"),
-            cast(contains_range_obj_combinations, RANGE).adjacent_to(
-                bounds_obj_combinations
-            ),
+            cast(r1, RANGE).adjacent_to(r2),
         )
         validate_q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
-            literal_column(
-                f"'{r1repr}'::{range_typ} -|- '{r2repr}'::{range_typ}"
-            ),
+            literal_column(f"'{r1}'::{range_typ} -|- '{r2}'::{range_typ}"),
         )
-        orig_row = connection.execute(q).first()
+
+        row = connection.execute(q).first()
         validate_row = connection.execute(validate_q).first()
-        eq_(orig_row, validate_row)
+        eq_(row, validate_row)
 
-        r1, r2, adjacent = orig_row
+        pg_res = row[0]
+        py_res = r1.adjacent_to(r2)
         eq_(
-            r1.adjacent_to(r2),
-            adjacent,
-            f"{r1}.adjacent_to({r2}) != {adjacent}",
+            py_res,
+            pg_res,
+            f"{r1}.adjacent_to({r2}): got {py_res}, expected {pg_res}",
         )
 
     @testing.combinations(
-        ("empty", "empty", False),
-        ("empty", "[{le},{re_}]", False),
-        ("[{le},{re_}]", "empty", False),
-        ("[{le},{re_}]", "[{le},{re_}]", False),
-        ("[{ll},{rh})", "[{le},{re_}]", False),
-        ("[{ll},{ll}]", "({le},{rh}]", True),
-        ("[{ll},{ll}]", "[{rh},{rh}]", True),
-        argnames="r1repr,r2repr,err",
+        *_common_ranges_to_test,
+        argnames="r1t",
     )
-    def test_union(self, connection, r1repr, r2repr, err):
-        data = self._value_values()
-
-        if r1repr != "empty":
-            r1repr = r1repr.format(**data)
-        if r2repr != "empty":
-            r2repr = r2repr.format(**data)
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower, r.lower + e, bounds="[]"),
+        lambda r, e: Range(r.upper + 4 * e, r.upper + 6 * e, bounds="()"),
+        argnames="r2t",
+    )
+    def test_union(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
 
         RANGE = self._col_type
         range_typ = self._col_str
 
         q = select(
-            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            cast(r1, RANGE).union(r2),
         )
-
-        r1, r2 = connection.execute(q).first()
-
-        resq = select(
+        validate_q = select(
             literal_column(f"'{r1}'::{range_typ}+'{r2}'::{range_typ}", RANGE),
         )
 
-        if err:
+        try:
+            pg_res = connection.execute(q).scalar()
+        except DataError:
+            connection.rollback()
             with expect_raises(DataError):
-                connection.execute(resq).scalar()
+                connection.execute(validate_q).scalar()
             with expect_raises(ValueError):
                 r1.union(r2)
         else:
-            union = connection.execute(resq).scalar()
+            validate_union = connection.execute(validate_q).scalar()
+            eq_(pg_res, validate_union)
+            py_res = r1.union(r2)
             eq_(
-                r1.union(r2),
-                union,
-                f"{r1}.union({r2}) != {union}",
+                py_res,
+                pg_res,
+                f"{r1}.union({r2}): got {py_res}, expected {pg_res}",
             )
 
     @testing.combinations(
-        lambda r, e: Range(empty=True),
-        lambda r, e: r,
+        *_common_ranges_to_test,
         lambda r, e: Range(r.lower, r.lower, bounds="[]"),
-        lambda r, e: Range(r.lower, r.upper, bounds="[]"),
         lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
         lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"),
         lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"),
         argnames="r1t",
     )
     @testing.combinations(
-        lambda r, e: Range(empty=True),
-        lambda r, e: r,
+        *_common_ranges_to_test,
         lambda r, e: Range(r.lower, r.lower, bounds="[]"),
         lambda r, e: Range(r.lower, r.upper - e, bounds="(]"),
         lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
@@ -4309,38 +4351,37 @@ class _RangeComparisonFixtures(_RangeTests):
         range_typ = self._col_str
 
         q = select(
-            literal_column(f"'{r1}'::{range_typ}", RANGE).label("r1"),
-            literal_column(f"'{r2}'::{range_typ}", RANGE).label("r2"),
+            cast(r1, RANGE).difference(r2),
         )
-
-        r1, r2 = connection.execute(q).first()
-
-        resq = select(
+        validate_q = select(
             literal_column(f"'{r1}'::{range_typ}-'{r2}'::{range_typ}", RANGE),
         )
 
         try:
-            difference = connection.execute(resq).scalar()
+            pg_res = connection.execute(q).scalar()
         except DataError:
+            connection.rollback()
+            with expect_raises(DataError):
+                connection.execute(validate_q).scalar()
             with expect_raises(ValueError):
                 r1.difference(r2)
         else:
+            validate_difference = connection.execute(validate_q).scalar()
+            eq_(pg_res, validate_difference)
+            py_res = r1.difference(r2)
             eq_(
-                r1.difference(r2),
-                difference,
-                f"{r1}.difference({r2}): got {r1.difference(r2)},"
-                f" expected {difference}",
+                py_res,
+                pg_res,
+                f"{r1}.difference({r2}): got {py_res}, expected {pg_res}",
             )
 
     @testing.combinations(
-        lambda r, e: Range(empty=True),
-        lambda r, e: r,
+        *_common_ranges_to_test,
         lambda r, e: Range(r.lower, r.lower, bounds="[]"),
         argnames="r1t",
     )
     @testing.combinations(
-        lambda r, e: Range(empty=True),
-        lambda r, e: r,
+        *_common_ranges_to_test,
         lambda r, e: Range(r.lower, r.lower, bounds="[]"),
         lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
         lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
@@ -4357,11 +4398,7 @@ class _RangeComparisonFixtures(_RangeTests):
             literal_column(f"'{r1}'::{range_typ} = '{r2}'::{range_typ}")
         )
         equal = connection.execute(q).scalar()
-        eq_(
-            r1 == r2,
-            equal,
-            f"{r1} == {r2} != {equal}",
-        )
+        eq_(r1 == r2, equal, f"{r1} == {r2}: got {r1 == r2}, expected {equal}")
 
         q = select(
             literal_column(f"'{r1}'::{range_typ} <> '{r2}'::{range_typ}")
@@ -4370,7 +4407,7 @@ class _RangeComparisonFixtures(_RangeTests):
         eq_(
             r1 != r2,
             different,
-            f"{r1} != {r2} != {different}",
+            f"{r1} != {r2}: got {r1 != r2}, expected {different}",
         )