From f00b65bbdbc51fde9356cda08dc080203f9dce4c Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 6 Dec 2020 02:50:50 +0000 Subject: [PATCH] Added range type fetching and registration Fixed quoting of [ ] chars for range types: unlike for composite they need quoting. --- psycopg3/psycopg3/types/range.py | 105 +++++++++++++++++++++++++++- tests/types/test_range.py | 116 ++++++++++++++++++++++++++++--- 2 files changed, 210 insertions(+), 11 deletions(-) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index 3f21a3683..56e8901be 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -4,14 +4,24 @@ Support for range types adaptation. # Copyright (C) 2020 The Psycopg Team -from typing import Any, cast, Dict, Generic, Optional, TypeVar, Type +import re +from typing import Any, Dict, Generic, List, Optional, TypeVar, Type, Union +from typing import cast, TYPE_CHECKING from decimal import Decimal from datetime import date, datetime -from ..oids import builtins +from .. import sql +from .. import errors as e +from ..oids import builtins, TypeInfo from ..adapt import Format, Dumper, Loader +from ..proto import AdaptContext + +from . import array from .composite import SequenceDumper, BaseCompositeLoader +if TYPE_CHECKING: + from ..connection import Connection, AsyncConnection + T = TypeVar("T") @@ -222,6 +232,8 @@ class RangeDumper(SequenceDumper): b",", ) + _re_needs_quotes = re.compile(br'[",\\\s()\[\]]') + class RangeLoader(BaseCompositeLoader, Generic[T]): """Generic loader for a range. @@ -342,3 +354,92 @@ class TimestampRangeLoader(RangeLoader[datetime]): class TimestampTZRangeLoader(RangeLoader[datetime]): subtype_oid = builtins["timestamptz"].oid cls = DateTimeTZRange + + +class RangeInfo(TypeInfo): + """Manage information about a range type. + + The class allows to: + + - read information about a range type using `fetch()` and `fetch_async()` + - configure a composite type adaptation using `register()` + """ + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + subtype_oid: int, + ): + super().__init__(name, oid, array_oid) + self.subtype_oid = subtype_oid + + @classmethod + def fetch( + cls, conn: "Connection", name: Union[str, sql.Identifier] + ) -> Optional["RangeInfo"]: + if isinstance(name, sql.Composable): + name = name.as_string(conn) + cur = conn.cursor(format=Format.BINARY) + cur.execute(cls._info_query, {"name": name}) + recs = cur.fetchall() + return cls._from_records(recs) + + @classmethod + async def fetch_async( + cls, conn: "AsyncConnection", name: Union[str, sql.Identifier] + ) -> Optional["RangeInfo"]: + if isinstance(name, sql.Composable): + name = name.as_string(conn) + cur = await conn.cursor(format=Format.BINARY) + await cur.execute(cls._info_query, {"name": name}) + recs = await cur.fetchall() + return cls._from_records(recs) + + def register( + self, + context: AdaptContext = None, + range_class: Optional[Type[Range[Any]]] = None, + ) -> None: + if not range_class: + range_class = type(self.name.title(), (Range,), {}) + + # generate and register a customized text dumper + dumper: Type[Dumper] = type( + f"{self.name.title()}Dumper", (RangeDumper,), {"oid": self.oid} + ) + dumper.register(range_class, context=context, format=Format.TEXT) + + # generate and register a customized text loader + loader: Type[Loader] = type( + f"{self.name.title()}Loader", + (RangeLoader,), + {"cls": range_class, "subtype_oid": self.subtype_oid}, + ) + loader.register(self.oid, context=context, format=Format.TEXT) + + if self.array_oid: + array.register( + self.array_oid, self.oid, context=context, name=self.name + ) + + @classmethod + def _from_records(cls, recs: List[Any]) -> Optional["RangeInfo"]: + if not recs: + return None + if len(recs) > 1: + raise e.ProgrammingError( + f"found {len(recs)} different ranges named {recs[0][0]}" + ) + + name, oid, array_oid, subtype = recs[0] + return cls(name, oid, array_oid, subtype) + + _info_query = """\ +select t.typname as name, t.oid as oid, t.typarray as array_oid, + r.rngsubtype as subtype_oid +from pg_type t +join pg_range r on t.oid = r.rngtypid +where t.oid = %(name)s::regtype +""" diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 91eb669a6..93ad24be4 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -1,15 +1,17 @@ import pytest -from psycopg3.types import range +from psycopg3.sql import Identifier +from psycopg3.oids import builtins +from psycopg3.types import range as mrange type2cls = { - "int4range": range.Int4Range, - "int8range": range.Int8Range, - "numrange": range.DecimalRange, - "daterange": range.DateRange, - "tsrange": range.DateTimeRange, - "tstzrange": range.DateTimeTZRange, + "int4range": mrange.Int4Range, + "int8range": mrange.Int8Range, + "numrange": mrange.DecimalRange, + "daterange": mrange.DateRange, + "tsrange": mrange.DateTimeRange, + "tstzrange": mrange.DateTimeTZRange, } type2sub = { "int4range": "int4", @@ -35,7 +37,7 @@ samples = [ "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_dump_builtin_range_empty(conn, pgtype): +def test_dump_builtin_empty(conn, pgtype): r = type2cls[pgtype](empty=True) cur = conn.cursor() cur.execute(f"select 'empty'::{pgtype} = %s", (r,)) @@ -58,7 +60,7 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds): "pgtype", "int4range int8range numrange daterange tsrange tstzrange".split(), ) -def test_load_builtin_range_empty(conn, pgtype): +def test_load_builtin_empty(conn, pgtype): r = type2cls[pgtype](empty=True) cur = conn.cursor() (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone() @@ -79,3 +81,99 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds): bounds = "[)" if r.lower_inc else "()" r = type(r)(r.lower, r.upper + 1, bounds) assert cur.fetchone()[0] == r + + +@pytest.fixture(scope="session") +def testrange(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + create schema if not exists testschema; + + drop type if exists testrange cascade; + drop type if exists testschema.testrange cascade; + + create type testrange as range (subtype = text, collation = "C"); + create type testschema.testrange as range (subtype = float8); + """ + ) + + +fetch_cases = [ + ("testrange", "text"), + ("testschema.testrange", "float8"), + (Identifier("testrange"), "text"), + (Identifier("testschema", "testrange"), "float8"), +] + + +@pytest.mark.parametrize("name, subtype", fetch_cases) +def test_fetch_info(conn, testrange, name, subtype): + info = mrange.RangeInfo.fetch(conn, name) + assert info.name == "testrange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == builtins[subtype].oid + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name, subtype", fetch_cases) +async def test_fetch_info_async(aconn, testrange, name, subtype): + info = await mrange.RangeInfo.fetch_async(aconn, name) + assert info.name == "testrange" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert info.subtype_oid == builtins[subtype].oid + + +def test_dump_custom_empty(conn, testrange): + class StrRange(mrange.Range): + pass + + info = mrange.RangeInfo.fetch(conn, "testrange") + info.register(conn, range_class=StrRange) + + r = StrRange(empty=True) + cur = conn.cursor() + cur.execute("select 'empty'::testrange = %s", (r,)) + assert cur.fetchone()[0] is True + + +def test_dump_quoting(conn, testrange): + class StrRange(mrange.Range): + pass + + info = mrange.RangeInfo.fetch(conn, "testrange") + info.register(conn, range_class=StrRange) + cur = conn.cursor() + for i in range(1, 254): + cur.execute( + "select ascii(lower(%(r)s)) = %(low)s and ascii(upper(%(r)s)) = %(up)s", + {"r": StrRange(chr(i), chr(i + 1)), "low": i, "up": i + 1}, + ) + assert cur.fetchone()[0] is True + + +def test_load_custom_empty(conn, testrange): + info = mrange.RangeInfo.fetch(conn, "testrange") + info.register(conn) + + cur = conn.cursor() + (got,) = cur.execute("select 'empty'::testrange").fetchone() + assert isinstance(got, mrange.Range) + assert got.isempty + + +def test_load_quoting(conn, testrange): + info = mrange.RangeInfo.fetch(conn, "testrange") + info.register(conn) + cur = conn.cursor() + for i in range(1, 254): + cur.execute( + "select testrange(chr(%(low)s::int), chr(%(up)s::int))", + {"low": i, "up": i + 1}, + ) + got = cur.fetchone()[0] + assert isinstance(got, mrange.Range) + assert ord(got.lower) == i + assert ord(got.upper) == i + 1 -- 2.47.2