From 90ec3922be4a9f57314cf8638c20ca14b06e8112 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 4 Oct 2021 03:50:15 +0200 Subject: [PATCH] Complete multirange implementation and test - add register_multirange() - check type of multirange items - add extensive dump/load tests --- psycopg/psycopg/types/multirange.py | 71 +++++- tests/types/__init__.py | 0 tests/types/test_multirange.py | 359 +++++++++++++++++++++++++++- tests/types/test_range.py | 24 +- 4 files changed, 434 insertions(+), 20 deletions(-) create mode 100644 tests/types/__init__.py diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py index c3bae8635..62bcd04b2 100644 --- a/psycopg/psycopg/types/multirange.py +++ b/psycopg/psycopg/types/multirange.py @@ -6,7 +6,7 @@ Support for multirange types adaptation. from decimal import Decimal from typing import Any, Generic, List, Iterable -from typing import MutableSequence, Optional, Union, overload +from typing import MutableSequence, Optional, Type, Union, overload from datetime import date, datetime from .. 
import errors as e @@ -24,7 +24,14 @@ from .range import dump_range_text, dump_range_binary, fail_dump class Multirange(MutableSequence[Range[T]]): def __init__(self, items: Iterable[Range[T]] = ()): - self._ranges: List[Range[T]] = list(items) + self._ranges: List[Range[T]] = list(map(self._check_type, items)) + + def _check_type(self, item: Any) -> Range[Any]: + if not isinstance(item, Range): + raise TypeError( + f"Multirange is a sequence of Range, got {type(item).__name__}" + ) + return item def __repr__(self) -> str: return f"{self.__class__.__name__}({self._ranges!r})" @@ -64,13 +71,21 @@ class Multirange(MutableSequence[Range[T]]): index: Union[int, slice], value: Union[Range[T], Iterable[Range[T]]], ) -> None: - self._ranges[index] = value # type: ignore + if isinstance(index, int): + self._check_type(value) + self._ranges[index] = self._check_type(value) + else: + if isinstance(value, Iterable): + value = map(self._check_type, value) + else: + value = [self._check_type(value)] + self._ranges[index] = value def __delitem__(self, index: Union[int, slice]) -> None: del self._ranges[index] def insert(self, index: int, value: Range[T]) -> None: - self._ranges.insert(index, value) + self._ranges.insert(index, self._check_type(value)) def __eq__(self, other: Any) -> bool: if not isinstance(other, Multirange): @@ -313,6 +328,54 @@ class MultirangeBinaryLoader(BaseMultirangeLoader[T]): return out +def register_multirange( + info: MultirangeInfo, context: Optional[AdaptContext] = None +) -> None: + """Register the adapters to load and dump a multirange type. + + :param info: The object with information about the multirange to register. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + Register loaders so that loading data of this type will result in a + `Multirange` whose ranges have bounds parsed as the right subtype. + + .. 
note:: + + Registering the adapters doesn't affect objects already created, even + if they are children of the registered context. For instance, + registering the adapter globally doesn't affect already existing + connections. + """ + # A friendly error warning instead of an AttributeError in case fetch() + # failed and it wasn't noticed. + if not info: + raise TypeError( + "no info passed. Is the requested multirange available?" + ) + + # Register arrays and type info + info.register(context) + + adapters = context.adapters if context else postgres.adapters + + # generate and register a customized text loader + loader: Type[MultirangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (MultirangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, loader) + + # generate and register a customized binary loader + bloader: Type[MultirangeBinaryLoader[Any]] = type( + f"{info.name.title()}BinaryLoader", + (MultirangeBinaryLoader,), + {"subtype_oid": info.subtype_oid}, + ) + adapters.register_loader(info.oid, bloader) + + # Text dumpers for builtin multirange types wrappers # These are registered on specific subtypes so that the upgrade mechanism # doesn't kick in. 
diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py index 63231e017..7721f63a8 100644 --- a/tests/types/test_multirange.py +++ b/tests/types/test_multirange.py @@ -1,18 +1,21 @@ import pickle +import datetime as dt +from decimal import Decimal import pytest +from psycopg import pq +from psycopg import errors as e +from psycopg.sql import Identifier from psycopg.adapt import PyFormat from psycopg.types.range import Range -from psycopg.types.multirange import Multirange +from psycopg.types import multirange +from psycopg.types.multirange import Multirange, MultirangeInfo +from psycopg.types.multirange import register_multirange -pytestmark = pytest.mark.pg(">= 14") - -mr_names = """int4multirange int8multirange nummultirange - datemultirange tsmultirange tstzmultirange""".split() +from .test_range import create_test_range -mr_classes = """Int4Multirange Int8Multirange NumericMultirange - DateMultirange TimestampMultirange TimestamptzMultirange""".split() +pytestmark = pytest.mark.pg(">= 14") class TestMultirangeObject: @@ -32,11 +35,41 @@ class TestMultirangeObject: assert mr[2] == Range(50, 60) assert mr[-2] == Range(30, 40) + def test_bad_type(self): + with pytest.raises(TypeError): + Multirange(Range(10, 20)) + + with pytest.raises(TypeError): + Multirange([10]) + + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + + with pytest.raises(TypeError): + mr[0] = "foo" + + with pytest.raises(TypeError): + mr[0:1] = "foo" + + with pytest.raises(TypeError): + mr[0:1] = ["foo"] + + with pytest.raises(TypeError): + mr.insert(0, "foo") + def test_setitem(self): mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) mr[1] = Range(31, 41) assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)]) + def test_setitem_slice(self): + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + mr[1:3] = 
[Range(31, 41), Range(51, 61)] + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(51, 61)]) + + mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) + mr[1:3] = Range(31, 41) + assert mr == Multirange([Range(10, 20), Range(31, 41)]) + def test_delitem(self): mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)]) del mr[1] @@ -45,6 +78,11 @@ class TestMultirangeObject: del mr[-2] assert mr == Multirange([Range(50, 60)]) + def test_insert(self): + mr = Multirange([Range(10, 20), Range(50, 60)]) + mr.insert(1, Range(31, 41)) + assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)]) + def test_relations(self): mr1 = Multirange([Range(10, 20), Range(30, 40)]) mr2 = Multirange([Range(11, 20), Range(30, 40)]) @@ -74,9 +112,316 @@ class TestMultirangeObject: assert repr(mr) == expected +tzinfo = dt.timezone(dt.timedelta(hours=2)) + +samples = [ + ("int4multirange", [Range(None, None, "()")]), + ("int4multirange", [Range(10, 20), Range(30, 40)]), + ("int8multirange", [Range(None, None, "()")]), + ("int8multirange", [Range(10, 20), Range(30, 40)]), + ( + "nummultirange", + [ + Range(None, Decimal(-100)), + Range(Decimal(100), Decimal("100.123")), + ], + ), + ( + "datemultirange", + [Range(dt.date(2000, 1, 1), dt.date(2020, 1, 1))], + ), + ( + "tsmultirange", + [ + Range( + dt.datetime(2000, 1, 1, 00, 00), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999), + ) + ], + ), + ( + "tstzmultirange", + [ + Range( + dt.datetime(2000, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2020, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + ), + Range( + dt.datetime(2030, 1, 1, 00, 00, tzinfo=tzinfo), + dt.datetime(2040, 1, 1, 23, 59, 59, 999999, tzinfo=tzinfo), + ), + ], + ), +] + +mr_names = """ + int4multirange int8multirange nummultirange + datemultirange tsmultirange tstzmultirange""".split() + +mr_classes = """ + Int4Multirange Int8Multirange NumericMultirange + DateMultirange TimestampMultirange TimestamptzMultirange""".split() + + 
@pytest.mark.parametrize("pgtype", mr_names) @pytest.mark.parametrize("fmt_in", PyFormat) def test_dump_builtin_empty(conn, pgtype, fmt_in): mr = Multirange() cur = conn.execute(f"select '{{}}'::{pgtype} = %{fmt_in}", (mr,)) assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_empty_wrapper(conn, wrapper, fmt_in): + dumper = getattr(multirange, wrapper + "Dumper") + wrapper = getattr(multirange, wrapper) + mr = wrapper() + rec = conn.execute( + f""" + select '{{}}' = %(mr){fmt_in}, + %(mr){fmt_in}::text, + pg_typeof(%(mr){fmt_in})::oid + """, + {"mr": mr}, + ).fetchone() + assert rec[0] is True, rec[1] + assert rec[2] == dumper.oid + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize( + "fmt_in", + [ + PyFormat.AUTO, + PyFormat.TEXT, + # There are many ways to work around this (use text, use a cast on the + # placeholder, use specific Range subclasses). + pytest.param( + PyFormat.BINARY, + marks=pytest.mark.xfail( + reason="can't dump array of untypes binary multirange without cast" + ), + ), + ], +) +def test_dump_builtin_array(conn, pgtype, fmt_in): + mr1 = Multirange() + mr2 = Multirange([Range(bounds="()")]) + cur = conn.execute( + f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in}", + ([mr1, mr2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_with_cast(conn, pgtype, fmt_in): + mr1 = Multirange() + mr2 = Multirange([Range(bounds="()")]) + cur = conn.execute( + f""" + select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}] = %{fmt_in}::{pgtype}[] + """, + ([mr1, mr2],), + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in): + wrapper = getattr(multirange, wrapper) + mr1 
= Multirange() + mr2 = Multirange([Range(bounds="()")]) + cur = conn.execute( + f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in}""", ([mr1, mr2],) + ) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype, ranges", samples) +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_dump_builtin_multirange(conn, pgtype, ranges, fmt_in): + mr = Multirange(ranges) + rname = pgtype.replace("multi", "") + phs = ", ".join([f"%s::{rname}"] * len(ranges)) + cur = conn.execute(f"select {pgtype}({phs}) = %{fmt_in}", ranges + [mr]) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_empty(conn, pgtype, fmt_out): + mr = Multirange() + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute(f"select '{{}}'::{pgtype}").fetchone() + assert type(got) is Multirange + assert got == mr + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_array(conn, pgtype, fmt_out): + mr1 = Multirange() + mr2 = Multirange([Range(bounds="()")]) + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute( + f"select array['{{}}'::{pgtype}, '{{(,)}}'::{pgtype}]" + ).fetchone() + assert got == [mr1, mr2] + + +@pytest.mark.parametrize("pgtype, ranges", samples) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_builtin_range(conn, pgtype, ranges, fmt_out): + mr = Multirange(ranges) + rname = pgtype.replace("multi", "") + phs = ", ".join([f"%s::{rname}"] * len(ranges)) + cur = conn.cursor(binary=fmt_out) + cur.execute(f"select {pgtype}({phs})", ranges) + assert cur.fetchone()[0] == mr + + +@pytest.mark.parametrize( + "min, max, bounds", + [ + ("2000,1,1", "2001,1,1", "[)"), + ("2000,1,1", None, "[)"), + (None, "2001,1,1", "()"), + (None, None, "()"), + (None, None, "empty"), + ], +) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in(conn, min, max, bounds, format): + cur = conn.cursor() + 
cur.execute( + "create table copymr (id serial primary key, mr datemultirange)" + ) + + if bounds != "empty": + min = dt.date(*map(int, min.split(","))) if min else None + max = dt.date(*map(int, max.split(","))) if max else None + r = Range(min, max, bounds) + else: + r = Range(empty=True) + + mr = Multirange([r]) + try: + with cur.copy( + f"copy copymr (mr) from stdin (format {format.name})" + ) as copy: + copy.write_row([mr]) + except e.InternalError_: + if not min and not max and format == pq.Format.BINARY: + pytest.xfail( + "TODO: add annotation to dump multirange with no type info" + ) + else: + raise + + rec = cur.execute("select mr from copymr order by id").fetchone() + if not r.isempty: + assert rec[0] == mr + else: + assert rec[0] == Multirange() + + +@pytest.mark.parametrize("wrapper", mr_classes) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_wrappers(conn, wrapper, format): + cur = conn.cursor() + cur.execute( + "create table copymr (id serial primary key, mr datemultirange)" + ) + + mr = getattr(multirange, wrapper)() + + with cur.copy( + f"copy copymr (mr) from stdin (format {format.name})" + ) as copy: + copy.write_row([mr]) + + rec = cur.execute("select mr from copymr order by id").fetchone() + assert rec[0] == mr + + +@pytest.mark.parametrize("pgtype", mr_names) +@pytest.mark.parametrize("format", pq.Format) +def test_copy_in_empty_set_type(conn, pgtype, format): + cur = conn.cursor() + cur.execute(f"create table copymr (id serial primary key, mr {pgtype})") + + mr = Multirange() + + with cur.copy( + f"copy copymr (mr) from stdin (format {format.name})" + ) as copy: + copy.set_types([pgtype]) + copy.write_row([mr]) + + rec = cur.execute("select mr from copymr order by id").fetchone() + assert rec[0] == mr + + +@pytest.fixture(scope="session") +def testmr(svcconn): + create_test_range(svcconn) + + +fetch_cases = [ + ("testmultirange", "text"), + ("testschema.testmultirange", "float8"), + (Identifier("testmultirange"), 
"text"), + (Identifier("testschema", "testmultirange"), "float8"), +] + + +@pytest.mark.parametrize("name, subtype", fetch_cases) +def test_fetch_info(conn, testmr, name, subtype): + info = MultirangeInfo.fetch(conn, name) + assert info.name == "testmultirange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == conn.adapters.types[subtype].oid + + +def test_fetch_info_not_found(conn): + assert MultirangeInfo.fetch(conn, "nosuchrange") is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, subtype", fetch_cases) +async def test_fetch_info_async(aconn, testmr, name, subtype): # noqa: F811 + info = await MultirangeInfo.fetch(aconn, name) + assert info.name == "testmultirange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == aconn.adapters.types[subtype].oid + + +@pytest.mark.asyncio +async def test_fetch_info_not_found_async(aconn): + assert await MultirangeInfo.fetch(aconn, "nosuchrange") is None + + +def test_dump_custom_empty(conn, testmr): + info = MultirangeInfo.fetch(conn, "testmultirange") + register_multirange(info, conn) + + r = Multirange() + cur = conn.execute("select '{}'::testmultirange = %s", (r,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_load_custom_empty(conn, testmr, fmt_out): + info = MultirangeInfo.fetch(conn, "testmultirange") + register_multirange(info, conn) + + cur = conn.cursor(binary=fmt_out) + (got,) = cur.execute("select '{}'::testmultirange").fetchone() + assert isinstance(got, Multirange) + assert not got diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 000d6e679..8e31ded39 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -4,8 +4,8 @@ from decimal import Decimal import pytest -import psycopg.errors from psycopg import pq +from psycopg import errors as e from psycopg.sql import Identifier from psycopg.adapt import PyFormat from psycopg.types 
import range as range_module @@ -48,11 +48,13 @@ samples = [ ), ] -range_names = """int4range int8range numrange - daterange tsrange tstzrange""".split() +range_names = """ + int4range int8range numrange daterange tsrange tstzrange + """.split() -range_classes = """Int4Range Int8Range NumericRange - DateRange TimestampRange TimestamptzRange""".split() +range_classes = """ + Int4Range Int8Range NumericRange DateRange TimestampRange TimestamptzRange + """.split() @pytest.mark.parametrize("pgtype", range_names) @@ -199,7 +201,7 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out): ], ) @pytest.mark.parametrize("format", pq.Format) -def test_copy_in_empty(conn, min, max, bounds, format): +def test_copy_in(conn, min, max, bounds, format): cur = conn.cursor() cur.execute("create table copyrange (id serial primary key, r daterange)") @@ -215,10 +217,10 @@ def test_copy_in_empty(conn, min, max, bounds, format): f"copy copyrange (r) from stdin (format {format.name})" ) as copy: copy.write_row([r]) - except psycopg.errors.ProtocolViolation: + except e.ProtocolViolation: if not min and not max and format == pq.Format.BINARY: pytest.xfail( - "TODO: add annotation to dump array with no type info" + "TODO: add annotation to dump ranges with no type info" ) else: raise @@ -267,7 +269,11 @@ def test_copy_in_empty_set_type(conn, bounds, pgtype, format): @pytest.fixture(scope="session") def testrange(svcconn): - svcconn.execute( + create_test_range(svcconn) + + +def create_test_range(conn): + conn.execute( """ create schema if not exists testschema; -- 2.47.2