]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added several range tests and fixed a pasto error
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Dec 2020 02:51:36 +0000 (03:51 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Dec 2020 02:51:36 +0000 (03:51 +0100)
psycopg3/psycopg3/types/range.py
tests/types/test_range.py

index 56e8901be79cbbcd94b1d5a0750c6cc148a92773..3432a8d82e3748c2830a18598d959a3fe01f4a04 100644 (file)
@@ -340,7 +340,7 @@ class NumericRangeLoader(RangeLoader[Decimal]):
 
 @Loader.text(builtins["daterange"].oid)
 class DateRangeLoader(RangeLoader[date]):
-    subtype_oid = builtins["numeric"].oid
+    subtype_oid = builtins["date"].oid
     cls = DateRange
 
 
index 93ad24be428e35edc6144ebd099d0c564522e7ed..003532797634bbf23835954f9a7778f7d8bc3ffb 100644 (file)
@@ -1,8 +1,13 @@
+import pickle
+import datetime as dt
+from decimal import Decimal
+
 import pytest
 
 from psycopg3.sql import Identifier
 from psycopg3.oids import builtins
 from psycopg3.types import range as mrange
+from psycopg3.types.range import Range
 
 
 type2cls = {
@@ -22,6 +27,8 @@ type2sub = {
     "tstzrange": "timestamptz",
 }
 
+tzinfo = dt.timezone(dt.timedelta(hours=2))
+
 samples = [
     ("int4range", None, None, "()"),
     ("int4range", 10, 20, "[]"),
@@ -29,7 +36,20 @@ samples = [
     ("int8range", None, None, "()"),
     ("int8range", 10, 20, "[)"),
     ("int8range", -(2 ** 63), (2 ** 63) - 1, "[)"),
-    # TODO: complete samples
+    ("numrange", Decimal(-100), Decimal("100.123"), "(]"),
+    ("daterange", dt.date(2000, 1, 1), dt.date(2020, 1, 1), "[)"),
+    (
+        "tsrange",
+        dt.datetime(2000, 1, 1, 00, 00),
+        dt.datetime(2020, 1, 1, 23, 59, 59, 999999),
+        "[]",
+    ),
+    (
+        "tstzrange",
+        dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo),
+        dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+        "()",
+    ),
 ]
 
 
@@ -39,8 +59,21 @@ samples = [
 )
 def test_dump_builtin_empty(conn, pgtype):
     r = type2cls[pgtype](empty=True)
-    cur = conn.cursor()
-    cur.execute(f"select 'empty'::{pgtype} = %s", (r,))
+    cur = conn.execute(f"select 'empty'::{pgtype} = %s", (r,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_dump_builtin_array(conn, pgtype):
+    r1 = type2cls[pgtype](empty=True)
+    r2 = type2cls[pgtype](bounds="()")
+    cur = conn.execute(
+        f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %s",
+        ([r1, r2],),
+    )
     assert cur.fetchone()[0] is True
 
 
@@ -48,8 +81,7 @@ def test_dump_builtin_empty(conn, pgtype):
 def test_dump_builtin_range(conn, pgtype, min, max, bounds):
     r = type2cls[pgtype](min, max, bounds)
     sub = type2sub[pgtype]
-    cur = conn.cursor()
-    cur.execute(
+    cur = conn.execute(
         f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s",
         (min, max, bounds, r),
     )
@@ -62,18 +94,46 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds):
 )
 def test_load_builtin_empty(conn, pgtype):
     r = type2cls[pgtype](empty=True)
-    cur = conn.cursor()
-    (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
+    (got,) = conn.execute(f"select 'empty'::{pgtype}").fetchone()
     assert type(got) is type2cls[pgtype]
     assert got == r
+    assert not got
+    assert got.isempty
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_load_builtin_inf(conn, pgtype):
+    r = type2cls[pgtype](bounds="()")
+    (got,) = conn.execute(f"select '(,)'::{pgtype}").fetchone()
+    assert type(got) is type2cls[pgtype]
+    assert got == r
+    assert got
+    assert not got.isempty
+    assert got.lower_inf
+    assert got.upper_inf
+
+
+@pytest.mark.parametrize(
+    "pgtype",
+    "int4range int8range numrange daterange tsrange tstzrange".split(),
+)
+def test_load_builtin_array(conn, pgtype):
+    r1 = type2cls[pgtype](empty=True)
+    r2 = type2cls[pgtype](bounds="()")
+    (got,) = conn.execute(
+        f"select array['empty'::{pgtype}, '(,)'::{pgtype}]"
+    ).fetchone()
+    assert got == [r1, r2]
 
 
 @pytest.mark.parametrize("pgtype, min, max, bounds", samples)
 def test_load_builtin_range(conn, pgtype, min, max, bounds):
     r = type2cls[pgtype](min, max, bounds)
     sub = type2sub[pgtype]
-    cur = conn.cursor()
-    cur.execute(
+    cur = conn.execute(
         f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
     )
     # normalise discrete ranges
@@ -85,8 +145,7 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds):
 
 @pytest.fixture(scope="session")
 def testrange(svcconn):
-    cur = svcconn.cursor()
-    cur.execute(
+    svcconn.execute(
         """
         create schema if not exists testschema;
 
@@ -116,6 +175,11 @@ def test_fetch_info(conn, testrange, name, subtype):
     assert info.subtype_oid == builtins[subtype].oid
 
 
+def test_fetch_info_not_found(conn):
+    with pytest.raises(conn.ProgrammingError):
+        mrange.RangeInfo.fetch(conn, "nosuchrange")
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("name, subtype", fetch_cases)
 async def test_fetch_info_async(aconn, testrange, name, subtype):
@@ -126,6 +190,12 @@ async def test_fetch_info_async(aconn, testrange, name, subtype):
     assert info.subtype_oid == builtins[subtype].oid
 
 
+@pytest.mark.asyncio
+async def test_fetch_info_not_found_async(aconn):
+    with pytest.raises(aconn.ProgrammingError):
+        await mrange.RangeInfo.fetch_async(aconn, "nosuchrange")
+
+
 def test_dump_custom_empty(conn, testrange):
     class StrRange(mrange.Range):
         pass
@@ -134,8 +204,7 @@ def test_dump_custom_empty(conn, testrange):
     info.register(conn, range_class=StrRange)
 
     r = StrRange(empty=True)
-    cur = conn.cursor()
-    cur.execute("select 'empty'::testrange = %s", (r,))
+    cur = conn.execute("select 'empty'::testrange = %s", (r,))
     assert cur.fetchone()[0] is True
 
 
@@ -158,8 +227,7 @@ def test_load_custom_empty(conn, testrange):
     info = mrange.RangeInfo.fetch(conn, "testrange")
     info.register(conn)
 
-    cur = conn.cursor()
-    (got,) = cur.execute("select 'empty'::testrange").fetchone()
+    (got,) = conn.execute("select 'empty'::testrange").fetchone()
     assert isinstance(got, mrange.Range)
     assert got.isempty
 
@@ -177,3 +245,276 @@ def test_load_quoting(conn, testrange):
         assert isinstance(got, mrange.Range)
         assert ord(got.lower) == i
         assert ord(got.upper) == i + 1
+
+
+class TestRangeObject:
+    def test_noparam(self):
+        r = Range()
+
+        assert not r.isempty
+        assert r.lower is None
+        assert r.upper is None
+        assert r.lower_inf
+        assert r.upper_inf
+        assert not r.lower_inc
+        assert not r.upper_inc
+
+    def test_empty(self):
+        r = Range(empty=True)
+
+        assert r.isempty
+        assert r.lower is None
+        assert r.upper is None
+        assert not r.lower_inf
+        assert not r.upper_inf
+        assert not r.lower_inc
+        assert not r.upper_inc
+
+    def test_nobounds(self):
+        r = Range(10, 20)
+        assert r.lower == 10
+        assert r.upper == 20
+        assert not r.isempty
+        assert not r.lower_inf
+        assert not r.upper_inf
+        assert r.lower_inc
+        assert not r.upper_inc
+
+    def test_bounds(self):
+        for bounds, lower_inc, upper_inc in [
+            ("[)", True, False),
+            ("(]", False, True),
+            ("()", False, False),
+            ("[]", True, True),
+        ]:
+            r = Range(10, 20, bounds)
+            assert r.lower == 10
+            assert r.upper == 20
+            assert not r.isempty
+            assert not r.lower_inf
+            assert not r.upper_inf
+            assert r.lower_inc == lower_inc
+            assert r.upper_inc == upper_inc
+
+    def test_keywords(self):
+        r = Range(upper=20)
+        r.lower is None
+        r.upper == 20
+        assert not r.isempty
+        assert r.lower_inf
+        assert not r.upper_inf
+        assert not r.lower_inc
+        assert not r.upper_inc
+
+        r = Range(lower=10, bounds="(]")
+        r.lower == 10
+        r.upper is None
+        assert not r.isempty
+        assert not r.lower_inf
+        assert r.upper_inf
+        assert not r.lower_inc
+        assert not r.upper_inc
+
+    def test_bad_bounds(self):
+        with pytest.raises(ValueError):
+            Range(bounds="(")
+        with pytest.raises(ValueError):
+            Range(bounds="[}")
+
+    def test_in(self):
+        r = Range(empty=True)
+        assert 10 not in r
+
+        r = Range()
+        assert 10 in r
+
+        r = Range(lower=10, bounds="[)")
+        assert 9 not in r
+        assert 10 in r
+        assert 11 in r
+
+        r = Range(lower=10, bounds="()")
+        assert 9 not in r
+        assert 10 not in r
+        assert 11 in r
+
+        r = Range(upper=20, bounds="()")
+        assert 19 in r
+        assert 20 not in r
+        assert 21 not in r
+
+        r = Range(upper=20, bounds="(]")
+        assert 19 in r
+        assert 20 in r
+        assert 21 not in r
+
+        r = Range(10, 20)
+        assert 9 not in r
+        assert 10 in r
+        assert 11 in r
+        assert 19 in r
+        assert 20 not in r
+        assert 21 not in r
+
+        r = Range(10, 20, "(]")
+        assert 9 not in r
+        assert 10 not in r
+        assert 11 in r
+        assert 19 in r
+        assert 20 in r
+        assert 21 not in r
+
+        r = Range(20, 10)
+        assert 9 not in r
+        assert 10 not in r
+        assert 11 not in r
+        assert 19 not in r
+        assert 20 not in r
+        assert 21 not in r
+
+    def test_nonzero(self):
+        assert Range()
+        assert Range(10, 20)
+        assert not Range(empty=True)
+
+    def test_eq_hash(self):
+        def assert_equal(r1, r2):
+            assert r1 == r2
+            assert hash(r1) == hash(r2)
+
+        assert_equal(Range(empty=True), Range(empty=True))
+        assert_equal(Range(), Range())
+        assert_equal(Range(10, None), Range(10, None))
+        assert_equal(Range(10, 20), Range(10, 20))
+        assert_equal(Range(10, 20), Range(10, 20, "[)"))
+        assert_equal(Range(10, 20, "[]"), Range(10, 20, "[]"))
+
+        def assert_not_equal(r1, r2):
+            assert r1 != r2
+            assert hash(r1) != hash(r2)
+
+        assert_not_equal(Range(10, 20), Range(10, 21))
+        assert_not_equal(Range(10, 20), Range(11, 20))
+        assert_not_equal(Range(10, 20, "[)"), Range(10, 20, "[]"))
+
+    def test_eq_wrong_type(self):
+        assert Range(10, 20) != ()
+
+    def test_eq_subclass(self):
+        class IntRange(mrange.DecimalRange):
+            pass
+
+        class PositiveIntRange(IntRange):
+            pass
+
+        assert Range(10, 20) == IntRange(10, 20)
+        assert PositiveIntRange(10, 20) == IntRange(10, 20)
+
+    # as the postgres docs describe for the server-side stuff,
+    # ordering is rather arbitrary, but will remain stable
+    # and consistent.
+
+    def test_lt_ordering(self):
+        assert Range(empty=True) < Range(0, 4)
+        assert not Range(1, 2) < Range(0, 4)
+        assert Range(0, 4) < Range(1, 2)
+        assert not Range(1, 2) < Range()
+        assert Range() < Range(1, 2)
+        assert not Range(1) < Range(upper=1)
+        assert not Range() < Range()
+        assert not Range(empty=True) < Range(empty=True)
+        assert not Range(1, 2) < Range(1, 2)
+        with pytest.raises(TypeError):
+            assert 1 < Range(1, 2)
+        with pytest.raises(TypeError):
+            assert not Range(1, 2) < 1
+
+    def test_gt_ordering(self):
+        assert not Range(empty=True) > Range(0, 4)
+        assert Range(1, 2) > Range(0, 4)
+        assert not Range(0, 4) > Range(1, 2)
+        assert Range(1, 2) > Range()
+        assert not Range() > Range(1, 2)
+        assert Range(1) > Range(upper=1)
+        assert not Range() > Range()
+        assert not Range(empty=True) > Range(empty=True)
+        assert not Range(1, 2) > Range(1, 2)
+        with pytest.raises(TypeError):
+            assert not 1 > Range(1, 2)
+        with pytest.raises(TypeError):
+            assert Range(1, 2) > 1
+
+    def test_le_ordering(self):
+        assert Range(empty=True) <= Range(0, 4)
+        assert not Range(1, 2) <= Range(0, 4)
+        assert Range(0, 4) <= Range(1, 2)
+        assert not Range(1, 2) <= Range()
+        assert Range() <= Range(1, 2)
+        assert not Range(1) <= Range(upper=1)
+        assert Range() <= Range()
+        assert Range(empty=True) <= Range(empty=True)
+        assert Range(1, 2) <= Range(1, 2)
+        with pytest.raises(TypeError):
+            assert 1 <= Range(1, 2)
+        with pytest.raises(TypeError):
+            assert not Range(1, 2) <= 1
+
+    def test_ge_ordering(self):
+        assert not Range(empty=True) >= Range(0, 4)
+        assert Range(1, 2) >= Range(0, 4)
+        assert not Range(0, 4) >= Range(1, 2)
+        assert Range(1, 2) >= Range()
+        assert not Range() >= Range(1, 2)
+        assert Range(1) >= Range(upper=1)
+        assert Range() >= Range()
+        assert Range(empty=True) >= Range(empty=True)
+        assert Range(1, 2) >= Range(1, 2)
+        with pytest.raises(TypeError):
+            assert not 1 >= Range(1, 2)
+        with pytest.raises(TypeError):
+            (Range(1, 2) >= 1)
+
+    def test_pickling(self):
+        r = Range(0, 4)
+        assert pickle.loads(pickle.dumps(r)) == r
+
+    def test_str(self):
+        """
+        Range types should have a short and readable ``str`` implementation.
+
+        Using ``repr`` for all string conversions can be very unreadable for
+        longer types like ``DateTimeTZRange``.
+        """
+
+        # Using the "u" prefix to make sure we have the proper return types in
+        # Python2
+        expected = [
+            "(0, 4)",
+            "[0, 4]",
+            "(0, 4]",
+            "[0, 4)",
+            "empty",
+        ]
+        results = []
+
+        for bounds in ("()", "[]", "(]", "[)"):
+            r = Range(0, 4, bounds=bounds)
+            results.append(str(r))
+
+        r = Range(empty=True)
+        results.append(str(r))
+        assert results == expected
+
+    def test_str_datetime(self):
+        """
+        Date-Time ranges should return a human-readable string as well on
+        string conversion.
+        """
+        tz = dt.timezone(dt.timedelta(hours=-5))
+        r = mrange.DateTimeTZRange(
+            dt.datetime(2010, 1, 1, tzinfo=tz),
+            dt.datetime(2011, 1, 1, tzinfo=tz),
+        )
+        expected = "[2010-01-01 00:00:00-05:00, 2011-01-01 00:00:00-05:00)"
+        result = str(r)
+        assert result == expected