]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Complete multirange implementation and test
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Oct 2021 01:50:15 +0000 (03:50 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Oct 2021 12:45:56 +0000 (14:45 +0200)
- add register_multirange()
- check type of multirange items
- add extensive dump/load tests

psycopg/psycopg/types/multirange.py
tests/types/__init__.py [new file with mode: 0644]
tests/types/test_multirange.py
tests/types/test_range.py

index c3bae8635207c56ca58763759558c906ee975632..62bcd04b26f45401d9ac8e25c0d8c4d929ac5f62 100644 (file)
@@ -6,7 +6,7 @@ Support for multirange types adaptation.
 
 from decimal import Decimal
 from typing import Any, Generic, List, Iterable
-from typing import MutableSequence, Optional, Union, overload
+from typing import MutableSequence, Optional, Type, Union, overload
 from datetime import date, datetime
 
 from .. import errors as e
@@ -24,7 +24,14 @@ from .range import dump_range_text, dump_range_binary, fail_dump
 
 class Multirange(MutableSequence[Range[T]]):
     def __init__(self, items: Iterable[Range[T]] = ()):
-        self._ranges: List[Range[T]] = list(items)
+        self._ranges: List[Range[T]] = list(map(self._check_type, items))
+
+    def _check_type(self, item: Any) -> Range[Any]:
+        if not isinstance(item, Range):
+            raise TypeError(
+                f"Multirange is a sequence of Range, got {type(item).__name__}"
+            )
+        return item
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self._ranges!r})"
@@ -64,13 +71,21 @@ class Multirange(MutableSequence[Range[T]]):
         index: Union[int, slice],
         value: Union[Range[T], Iterable[Range[T]]],
     ) -> None:
-        self._ranges[index] = value  # type: ignore
+        if isinstance(index, int):
+            self._check_type(value)
+            self._ranges[index] = self._check_type(value)
+        else:
+            if isinstance(value, Iterable):
+                value = map(self._check_type, value)
+            else:
+                value = [self._check_type(value)]
+            self._ranges[index] = value
 
     def __delitem__(self, index: Union[int, slice]) -> None:
         del self._ranges[index]
 
     def insert(self, index: int, value: Range[T]) -> None:
-        self._ranges.insert(index, value)
+        self._ranges.insert(index, self._check_type(value))
 
     def __eq__(self, other: Any) -> bool:
         if not isinstance(other, Multirange):
@@ -313,6 +328,54 @@ class MultirangeBinaryLoader(BaseMultirangeLoader[T]):
         return out
 
 
+def register_multirange(
+    info: MultirangeInfo, context: Optional[AdaptContext] = None
+) -> None:
+    """Register the adapters to load and dump a multirange type.
+
+    :param info: The object with the information about the range to register.
+    :param context: The context where to register the adapters. If `!None`,
+        register it globally.
+
+    Register loaders so that loading data of this type will result in a `Range`
+    with bounds parsed as the right subtype.
+
+    .. note::
+
+        Registering the adapters doesn't affect objects already created, even
+        if they are children of the registered context. For instance,
+        registering the adapter globally doesn't affect already existing
+        connections.
+    """
+    # A friendly error warning instead of an AttributeError in case fetch()
+    # failed and it wasn't noticed.
+    if not info:
+        raise TypeError(
+            "no info passed. Is the requested multirange available?"
+        )
+
+    # Register arrays and type info
+    info.register(context)
+
+    adapters = context.adapters if context else postgres.adapters
+
+    # generate and register a customized text loader
+    loader: Type[MultirangeLoader[Any]] = type(
+        f"{info.name.title()}Loader",
+        (MultirangeLoader,),
+        {"subtype_oid": info.subtype_oid},
+    )
+    adapters.register_loader(info.oid, loader)
+
+    # generate and register a customized binary loader
+    bloader: Type[MultirangeBinaryLoader[Any]] = type(
+        f"{info.name.title()}BinaryLoader",
+        (MultirangeBinaryLoader,),
+        {"subtype_oid": info.subtype_oid},
+    )
+    adapters.register_loader(info.oid, bloader)
+
+
 # Text dumpers for builtin multirange types wrappers
 # These are registered on specific subtypes so that the upgrade mechanism
 # doesn't kick in.
diff --git a/tests/types/__init__.py b/tests/types/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
index 63231e0175da9ae3b8157fca6e1f9ab9def4f737..7721f63a81376704453536babd2d17b4b6d5999c 100644 (file)
@@ -1,18 +1,21 @@
 import pickle
+import datetime as dt
+from decimal import Decimal
 
 import pytest
 
+from psycopg import pq
+from psycopg import errors as e
+from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat
 from psycopg.types.range import Range
-from psycopg.types.multirange import Multirange
+from psycopg.types import multirange
+from psycopg.types.multirange import Multirange, MultirangeInfo
+from psycopg.types.multirange import register_multirange
 
-pytestmark = pytest.mark.pg(">= 14")
-
-mr_names = """int4multirange int8multirange nummultirange
-    datemultirange tsmultirange tstzmultirange""".split()
+from .test_range import create_test_range
 
-mr_classes = """Int4Multirange Int8Multirange NumericMultirange
-    DateMultirange TimestampMultirange TimestamptzMultirange""".split()
+pytestmark = pytest.mark.pg(">= 14")
 
 
 class TestMultirangeObject:
@@ -32,11 +35,41 @@ class TestMultirangeObject:
         assert mr[2] == Range(50, 60)
         assert mr[-2] == Range(30, 40)
 
+    def test_bad_type(self):
+        with pytest.raises(TypeError):
+            Multirange(Range(10, 20))
+
+        with pytest.raises(TypeError):
+            Multirange([10])
+
+        mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+
+        with pytest.raises(TypeError):
+            mr[0] = "foo"
+
+        with pytest.raises(TypeError):
+            mr[0:1] = "foo"
+
+        with pytest.raises(TypeError):
+            mr[0:1] = ["foo"]
+
+        with pytest.raises(TypeError):
+            mr.insert(0, "foo")
+
     def test_setitem(self):
         mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
         mr[1] = Range(31, 41)
         assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
 
+    def test_setitem_slice(self):
+        mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+        mr[1:3] = [Range(31, 41), Range(51, 61)]
+        assert mr == Multirange([Range(10, 20), Range(31, 41), Range(51, 61)])
+
+        mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+        mr[1:3] = Range(31, 41)
+        assert mr == Multirange([Range(10, 20), Range(31, 41)])
+
     def test_delitem(self):
         mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
         del mr[1]
@@ -45,6 +78,11 @@ class TestMultirangeObject:
         del mr[-2]
         assert mr == Multirange([Range(50, 60)])
 
+    def test_insert(self):
+        mr = Multirange([Range(10, 20), Range(50, 60)])
+        mr.insert(1, Range(31, 41))
+        assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
+
     def test_relations(self):
         mr1 = Multirange([Range(10, 20), Range(30, 40)])
         mr2 = Multirange([Range(11, 20), Range(30, 40)])
@@ -74,9 +112,316 @@ class TestMultirangeObject:
         assert repr(mr) == expected
 
 
+tzinfo = dt.timezone(dt.timedelta(hours=2))
+
+samples = [
+    ("int4multirange", [Range(None, None, "()")]),
+    ("int4multirange", [Range(10, 20), Range(30, 40)]),
+    ("int8multirange", [Range(None, None, "()")]),
+    ("int8multirange", [Range(10, 20), Range(30, 40)]),
+    (
+        "nummultirange",
+        [
+            Range(None, Decimal(-100)),
+            Range(Decimal(100), Decimal("100.123")),
+        ],
+    ),
+    (
+        "datemultirange",
+        [Range(dt.date(2000, 1, 1), dt.date(2020, 1, 1))],
+    ),
+    (
+        "tsmultirange",
+        [
+            Range(
+                dt.datetime(2000, 1, 1, 00, 00),
+                dt.datetime(2020, 1, 1, 23, 59, 59, 999999),
+            )
+        ],
+    ),
+    (
+        "tstzmultirange",
+        [
+            Range(
+                dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo),
+                dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+            ),
+            Range(
+                dt.datetime(2030, 1, 1, 00, 00, tzinfo=tzinfo),
+                dt.datetime(2040, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo),
+            ),
+        ],
+    ),
+]
+
+mr_names = """
+    int4multirange int8multirange nummultirange
+    datemultirange tsmultirange tstzmultirange""".split()
+
+mr_classes = """
+    Int4Multirange Int8Multirange NumericMultirange
+    DateMultirange TimestampMultirange TimestamptzMultirange""".split()
+
+
 @pytest.mark.parametrize("pgtype", mr_names)
 @pytest.mark.parametrize("fmt_in", PyFormat)
 def test_dump_builtin_empty(conn, pgtype, fmt_in):
     mr = Multirange()
     cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in}", (mr,))
     assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in):
+    dumper = getattr(multirange, wrapper + "Dumper")
+    wrapper = getattr(multirange, wrapper)
+    mr = wrapper()
+    rec = conn.execute(
+        f"""
+        select '{{}}' = %(mr){fmt_in},
+            %(mr){fmt_in}::text,
+            pg_typeof(%(mr){fmt_in})::oid
+        """,
+        {"mr": mr},
+    ).fetchone()
+    assert rec[0] is True, rec[1]
+    assert rec[2] == dumper.oid
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize(
+    "fmt_in",
+    [
+        PyFormat.AUTO,
+        PyFormat.TEXT,
+        # There are many ways to work around this (use text, use a cast on the
+        # placeholder, use specific Range subclasses).
+        pytest.param(
+            PyFormat.BINARY,
+            marks=pytest.mark.xfail(
+                reason="can't dump array of untypes binary multirange without cast"
+            ),
+        ),
+    ],
+)
+def test_dump_builtin_array(conn, pgtype, fmt_in):
+    mr1 = Multirange()
+    mr2 = Multirange([Range(bounds="()")])
+    cur = conn.execute(
+        f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in}",
+        ([mr1, mr2],),
+    )
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in):
+    mr1 = Multirange()
+    mr2 = Multirange([Range(bounds="()")])
+    cur = conn.execute(
+        f"""
+        select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in}::{pgtype}[]
+        """,
+        ([mr1, mr2],),
+    )
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
+    wrapper = getattr(multirange, wrapper)
+    mr1 = Multirange()
+    mr2 = Multirange([Range(bounds="()")])
+    cur = conn.execute(
+        f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in}""", ([mr1, mr2],)
+    )
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype, ranges", samples)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_dump_builtin_multirange(conn, pgtype, ranges, fmt_in):
+    mr = Multirange(ranges)
+    rname = pgtype.replace("multi", "")
+    phs = ", ".join([f"%s::{rname}"] * len(ranges))
+    cur = conn.execute(f"select {pgtype}({phs}) = %{fmt_in}", ranges + [mr])
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_empty(conn, pgtype, fmt_out):
+    mr = Multirange()
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute(f"select '{{}}'::{pgtype}").fetchone()
+    assert type(got) is Multirange
+    assert got == mr
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_array(conn, pgtype, fmt_out):
+    mr1 = Multirange()
+    mr2 = Multirange([Range(bounds="()")])
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute(
+        f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}]"
+    ).fetchone()
+    assert got == [mr1, mr2]
+
+
+@pytest.mark.parametrize("pgtype, ranges", samples)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_builtin_range(conn, pgtype, ranges, fmt_out):
+    mr = Multirange(ranges)
+    rname = pgtype.replace("multi", "")
+    phs = ", ".join([f"%s::{rname}"] * len(ranges))
+    cur = conn.cursor(binary=fmt_out)
+    cur.execute(f"select {pgtype}({phs})", ranges)
+    assert cur.fetchone()[0] == mr
+
+
+@pytest.mark.parametrize(
+    "min, max, bounds",
+    [
+        ("2000,1,1", "2001,1,1", "[)"),
+        ("2000,1,1", None, "[)"),
+        (None, "2001,1,1", "()"),
+        (None, None, "()"),
+        (None, None, "empty"),
+    ],
+)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in(conn, min, max, bounds, format):
+    cur = conn.cursor()
+    cur.execute(
+        "create table copymr (id serial primary key, mr datemultirange)"
+    )
+
+    if bounds != "empty":
+        min = dt.date(*map(int, min.split(","))) if min else None
+        max = dt.date(*map(int, max.split(","))) if max else None
+        r = Range(min, max, bounds)
+    else:
+        r = Range(empty=True)
+
+    mr = Multirange([r])
+    try:
+        with cur.copy(
+            f"copy copymr (mr) from stdin (format {format.name})"
+        ) as copy:
+            copy.write_row([mr])
+    except e.InternalError_:
+        if not min and not max and format == pq.Format.BINARY:
+            pytest.xfail(
+                "TODO: add annotation to dump multirange with no type info"
+            )
+        else:
+            raise
+
+    rec = cur.execute("select mr from copymr order by id").fetchone()
+    if not r.isempty:
+        assert rec[0] == mr
+    else:
+        assert rec[0] == Multirange()
+
+
+@pytest.mark.parametrize("wrapper", mr_classes)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_wrappers(conn, wrapper, format):
+    cur = conn.cursor()
+    cur.execute(
+        "create table copymr (id serial primary key, mr datemultirange)"
+    )
+
+    mr = getattr(multirange, wrapper)()
+
+    with cur.copy(
+        f"copy copymr (mr) from stdin (format {format.name})"
+    ) as copy:
+        copy.write_row([mr])
+
+    rec = cur.execute("select mr from copymr order by id").fetchone()
+    assert rec[0] == mr
+
+
+@pytest.mark.parametrize("pgtype", mr_names)
+@pytest.mark.parametrize("format", pq.Format)
+def test_copy_in_empty_set_type(conn, pgtype, format):
+    cur = conn.cursor()
+    cur.execute(f"create table copymr (id serial primary key, mr {pgtype})")
+
+    mr = Multirange()
+
+    with cur.copy(
+        f"copy copymr (mr) from stdin (format {format.name})"
+    ) as copy:
+        copy.set_types([pgtype])
+        copy.write_row([mr])
+
+    rec = cur.execute("select mr from copymr order by id").fetchone()
+    assert rec[0] == mr
+
+
+@pytest.fixture(scope="session")
+def testmr(svcconn):
+    create_test_range(svcconn)
+
+
+fetch_cases = [
+    ("testmultirange", "text"),
+    ("testschema.testmultirange", "float8"),
+    (Identifier("testmultirange"), "text"),
+    (Identifier("testschema", "testmultirange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testmr, name, subtype):
+    info = MultirangeInfo.fetch(conn, name)
+    assert info.name == "testmultirange"
+    assert info.oid > 0
+    assert info.oid != info.array_oid > 0
+    assert info.subtype_oid == conn.adapters.types[subtype].oid
+
+
+def test_fetch_info_not_found(conn):
+    assert MultirangeInfo.fetch(conn, "nosuchrange") is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testmr, name, subtype):  # noqa: F811
+    info = await MultirangeInfo.fetch(aconn, name)
+    assert info.name == "testmultirange"
+    assert info.oid > 0
+    assert info.oid != info.array_oid > 0
+    assert info.subtype_oid == aconn.adapters.types[subtype].oid
+
+
+@pytest.mark.asyncio
+async def test_fetch_info_not_found_async(aconn):
+    assert await MultirangeInfo.fetch(aconn, "nosuchrange") is None
+
+
+def test_dump_custom_empty(conn, testmr):
+    info = MultirangeInfo.fetch(conn, "testmultirange")
+    register_multirange(info, conn)
+
+    r = Multirange()
+    cur = conn.execute("select '{}'::testmultirange = %s", (r,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_custom_empty(conn, testmr, fmt_out):
+    info = MultirangeInfo.fetch(conn, "testmultirange")
+    register_multirange(info, conn)
+
+    cur = conn.cursor(binary=fmt_out)
+    (got,) = cur.execute("select '{}'::testmultirange").fetchone()
+    assert isinstance(got, Multirange)
+    assert not got
index 000d6e6793cd238fd530e0bf0381e621f1a51f75..8e31ded3998ed5640ecfb6b9955472874eabe7fc 100644 (file)
@@ -4,8 +4,8 @@ from decimal import Decimal
 
 import pytest
 
-import psycopg.errors
 from psycopg import pq
+from psycopg import errors as e
 from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat
 from psycopg.types import range as range_module
@@ -48,11 +48,13 @@ samples = [
     ),
 ]
 
-range_names = """int4range int8range numrange
-    daterange tsrange tstzrange""".split()
+range_names = """
+    int4range int8range numrange daterange tsrange tstzrange
+    """.split()
 
-range_classes = """Int4Range Int8Range NumericRange
-    DateRange TimestampRange TimestamptzRange""".split()
+range_classes = """
+    Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange
+    """.split()
 
 
 @pytest.mark.parametrize("pgtype", range_names)
@@ -199,7 +201,7 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out):
     ],
 )
 @pytest.mark.parametrize("format", pq.Format)
-def test_copy_in_empty(conn, min, max, bounds, format):
+def test_copy_in(conn, min, max, bounds, format):
     cur = conn.cursor()
     cur.execute("create table copyrange (id serial primary key, r daterange)")
 
@@ -215,10 +217,10 @@ def test_copy_in_empty(conn, min, max, bounds, format):
             f"copy copyrange (r) from stdin (format {format.name})"
         ) as copy:
             copy.write_row([r])
-    except psycopg.errors.ProtocolViolation:
+    except e.ProtocolViolation:
         if not min and not max and format == pq.Format.BINARY:
             pytest.xfail(
-                "TODO: add annotation to dump array with no type info"
+                "TODO: add annotation to dump ranges with no type info"
             )
         else:
             raise
@@ -267,7 +269,11 @@ def test_copy_in_empty_set_type(conn, bounds, pgtype, format):
 
 @pytest.fixture(scope="session")
 def testrange(svcconn):
-    svcconn.execute(
+    create_test_range(svcconn)
+
+
+def create_test_range(conn):
+    conn.execute(
         """
         create schema if not exists testschema;