From: Lele Gaifax Date: Wed, 9 Nov 2022 08:47:37 +0000 (+0100) Subject: Improve test_difference() cases and fix difference() accordingly X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3ea6d2fbd8e94d11acce591fa8e380ff7d0d17db;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve test_difference() cases and fix difference() accordingly --- diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 7327f6f1cf..d12b4dffe6 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -525,7 +525,12 @@ class Range(Generic[_T]): # middle if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0: rupper_b = ")" if olower_b == "[" else "]" - if self._compare_edges(slower, slower_b, olower, rupper_b) == 0: + if ( + slower_b != "[" + and rupper_b != "]" + and self._compare_edges(slower, slower_b, olower, rupper_b) + == 0 + ): return Range(None, None, empty=True) else: return Range(slower, olower, bounds=slower_b + rupper_b) @@ -533,8 +538,13 @@ class Range(Generic[_T]): # If this range starts in the middle of the other and extends to its # right if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0: - rlower_b = "(" if oupper_b == "]" else "(" - if self._compare_edges(oupper, rlower_b, supper, supper_b) == 0: + rlower_b = "(" if oupper_b == "]" else "[" + if ( + rlower_b != "[" + and supper_b != "]" + and self._compare_edges(oupper, rlower_b, supper, supper_b) + == 0 + ): return Range(None, None, empty=True) else: return Range(oupper, supper, bounds=rlower_b + supper_b) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 128041e4ad..faa1af7b19 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4280,29 +4280,37 @@ class _RangeComparisonFixtures(_RangeTests): ) @testing.combinations( - ("empty", "empty", False), - ("empty", "[{le},{re_}]", False), - ("[{le},{re_}]", "empty", False), - ("[{ll},{ih}]", "({le},{ih}]", False), - ("[{ll},{rh})", "[{le},{re_}]", False), - ("[{le},{re_}]", "({le},{re_})", True), - ("[{ll},{rh}]", "[{le},{re_}]", True), - argnames="r1repr,r2repr,err", + lambda r, e: Range(empty=True), + lambda r, e: r, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"), + argnames="r1t", ) - def test_difference(self, connection, r1repr, r2repr, err): - data = self._value_values() - - if r1repr != "empty": - r1repr = r1repr.format(**data) - if r2repr != "empty": - r2repr = r2repr.format(**data) + @testing.combinations( + lambda r, e: Range(empty=True), + lambda r, e: r, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="()"), + argnames="r2t", + ) + def test_difference(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) RANGE = self._col_type range_typ = self._col_str q = select( - literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), - literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column(f"'{r1}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2}'::{range_typ}", RANGE).label("r2"), ) r1, r2 = connection.execute(q).first() @@ -4311,17 +4319,17 @@ class _RangeComparisonFixtures(_RangeTests): literal_column(f"'{r1}'::{range_typ}-'{r2}'::{range_typ}", RANGE), ) - if err: - with expect_raises(DataError): - connection.execute(resq).scalar() + try: + difference = connection.execute(resq).scalar() + except DataError: with expect_raises(ValueError): r1.difference(r2) else: - difference = connection.execute(resq).scalar() eq_( r1.difference(r2), difference, - f"{r1}.difference({r2}) != {difference}", + f"{r1}.difference({r2}): got {r1.difference(r2)}," + f" expected {difference}", ) @testing.combinations(