]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added range type fetching and registration
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 6 Dec 2020 02:50:50 +0000 (02:50 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Dec 2020 01:43:50 +0000 (02:43 +0100)
Fixed quoting of [ ] chars for range types: unlike for composite they
need quoting.

psycopg3/psycopg3/types/range.py
tests/types/test_range.py

index 3f21a36831a89c41b98d862622e5c8ec5fb8d087..56e8901be79cbbcd94b1d5a0750c6cc148a92773 100644 (file)
@@ -4,14 +4,24 @@ Support for range types adaptation.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, cast, Dict, Generic, Optional, TypeVar, Type
+import re
+from typing import Any, Dict, Generic, List, Optional, TypeVar, Type, Union
+from typing import cast, TYPE_CHECKING
 from decimal import Decimal
 from datetime import date, datetime
 
-from ..oids import builtins
+from .. import sql
+from .. import errors as e
+from ..oids import builtins, TypeInfo
 from ..adapt import Format, Dumper, Loader
+from ..proto import AdaptContext
+
+from . import array
 from .composite import SequenceDumper, BaseCompositeLoader
 
+if TYPE_CHECKING:
+    from ..connection import Connection, AsyncConnection
+
 T = TypeVar("T")
 
 
@@ -222,6 +232,8 @@ class RangeDumper(SequenceDumper):
                 b",",
             )
 
+    _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+
 
 class RangeLoader(BaseCompositeLoader, Generic[T]):
     """Generic loader for a range.
@@ -342,3 +354,92 @@ class TimestampRangeLoader(RangeLoader[datetime]):
 class TimestampTZRangeLoader(RangeLoader[datetime]):
     subtype_oid = builtins["timestamptz"].oid
     cls = DateTimeTZRange
+
+
+class RangeInfo(TypeInfo):
+    """Manage information about a range type.
+
+    The class allows to:
+
+    - read information about a range type using `fetch()` and `fetch_async()`
+    - configure a composite type adaptation using `register()`
+    """
+
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        subtype_oid: int,
+    ):
+        super().__init__(name, oid, array_oid)
+        self.subtype_oid = subtype_oid
+
+    @classmethod
+    def fetch(
+        cls, conn: "Connection", name: Union[str, sql.Identifier]
+    ) -> Optional["RangeInfo"]:
+        if isinstance(name, sql.Composable):
+            name = name.as_string(conn)
+        cur = conn.cursor(format=Format.BINARY)
+        cur.execute(cls._info_query, {"name": name})
+        recs = cur.fetchall()
+        return cls._from_records(recs)
+
+    @classmethod
+    async def fetch_async(
+        cls, conn: "AsyncConnection", name: Union[str, sql.Identifier]
+    ) -> Optional["RangeInfo"]:
+        if isinstance(name, sql.Composable):
+            name = name.as_string(conn)
+        cur = await conn.cursor(format=Format.BINARY)
+        await cur.execute(cls._info_query, {"name": name})
+        recs = await cur.fetchall()
+        return cls._from_records(recs)
+
+    def register(
+        self,
+        context: AdaptContext = None,
+        range_class: Optional[Type[Range[Any]]] = None,
+    ) -> None:
+        if not range_class:
+            range_class = type(self.name.title(), (Range,), {})
+
+        # generate and register a customized text dumper
+        dumper: Type[Dumper] = type(
+            f"{self.name.title()}Dumper", (RangeDumper,), {"oid": self.oid}
+        )
+        dumper.register(range_class, context=context, format=Format.TEXT)
+
+        # generate and register a customized text loader
+        loader: Type[Loader] = type(
+            f"{self.name.title()}Loader",
+            (RangeLoader,),
+            {"cls": range_class, "subtype_oid": self.subtype_oid},
+        )
+        loader.register(self.oid, context=context, format=Format.TEXT)
+
+        if self.array_oid:
+            array.register(
+                self.array_oid, self.oid, context=context, name=self.name
+            )
+
+    @classmethod
+    def _from_records(cls, recs: List[Any]) -> Optional["RangeInfo"]:
+        if not recs:
+            return None
+        if len(recs) > 1:
+            raise e.ProgrammingError(
+                f"found {len(recs)} different ranges named {recs[0][0]}"
+            )
+
+        name, oid, array_oid, subtype = recs[0]
+        return cls(name, oid, array_oid, subtype)
+
+    _info_query = """\
+select t.typname as name, t.oid as oid, t.typarray as array_oid,
+    r.rngsubtype as subtype_oid
+from pg_type t
+join pg_range r on t.oid = r.rngtypid
+where t.oid = %(name)s::regtype
+"""
index 91eb669a6efc48ea8c7cd1b56ea39adc17e52442..93ad24be428e35edc6144ebd099d0c564522e7ed 100644 (file)
@@ -1,15 +1,17 @@
 import pytest
 
-from psycopg3.types import range
+from psycopg3.sql import Identifier
+from psycopg3.oids import builtins
+from psycopg3.types import range as mrange
 
 
 type2cls = {
-    "int4range": range.Int4Range,
-    "int8range": range.Int8Range,
-    "numrange": range.DecimalRange,
-    "daterange": range.DateRange,
-    "tsrange": range.DateTimeRange,
-    "tstzrange": range.DateTimeTZRange,
+    "int4range": mrange.Int4Range,
+    "int8range": mrange.Int8Range,
+    "numrange": mrange.DecimalRange,
+    "daterange": mrange.DateRange,
+    "tsrange": mrange.DateTimeRange,
+    "tstzrange": mrange.DateTimeTZRange,
 }
 type2sub = {
     "int4range": "int4",
@@ -35,7 +37,7 @@ samples = [
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_dump_builtin_range_empty(conn, pgtype):
+def test_dump_builtin_empty(conn, pgtype):
     r = type2cls[pgtype](empty=True)
     cur = conn.cursor()
     cur.execute(f"select 'empty'::{pgtype} = %s", (r,))
@@ -58,7 +60,7 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds):
     "pgtype",
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
-def test_load_builtin_range_empty(conn, pgtype):
+def test_load_builtin_empty(conn, pgtype):
     r = type2cls[pgtype](empty=True)
     cur = conn.cursor()
     (got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
@@ -79,3 +81,99 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds):
         bounds = "[)" if r.lower_inc else "()"
         r = type(r)(r.lower, r.upper + 1, bounds)
     assert cur.fetchone()[0] == r
+
+
+@pytest.fixture(scope="session")
+def testrange(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        create schema if not exists testschema;
+
+        drop type if exists testrange cascade;
+        drop type if exists testschema.testrange cascade;
+
+        create type testrange as range (subtype = text, collation = "C");
+        create type testschema.testrange as range (subtype = float8);
+        """
+    )
+
+
+fetch_cases = [
+    ("testrange", "text"),
+    ("testschema.testrange", "float8"),
+    (Identifier("testrange"), "text"),
+    (Identifier("testschema", "testrange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testrange, name, subtype):
+    info = mrange.RangeInfo.fetch(conn, name)
+    assert info.name == "testrange"
+    assert info.oid > 0
+    assert info.oid != info.array_oid > 0
+    assert info.subtype_oid == builtins[subtype].oid
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testrange, name, subtype):
+    info = await mrange.RangeInfo.fetch_async(aconn, name)
+    assert info.name == "testrange"
+    assert info.oid > 0
+    assert info.oid != info.array_oid > 0
+    assert info.subtype_oid == builtins[subtype].oid
+
+
+def test_dump_custom_empty(conn, testrange):
+    class StrRange(mrange.Range):
+        pass
+
+    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info.register(conn, range_class=StrRange)
+
+    r = StrRange(empty=True)
+    cur = conn.cursor()
+    cur.execute("select 'empty'::testrange = %s", (r,))
+    assert cur.fetchone()[0] is True
+
+
+def test_dump_quoting(conn, testrange):
+    class StrRange(mrange.Range):
+        pass
+
+    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info.register(conn, range_class=StrRange)
+    cur = conn.cursor()
+    for i in range(1, 254):
+        cur.execute(
+            "select ascii(lower(%(r)s)) = %(low)s and ascii(upper(%(r)s)) = %(up)s",
+            {"r": StrRange(chr(i), chr(i + 1)), "low": i, "up": i + 1},
+        )
+        assert cur.fetchone()[0] is True
+
+
+def test_load_custom_empty(conn, testrange):
+    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info.register(conn)
+
+    cur = conn.cursor()
+    (got,) = cur.execute("select 'empty'::testrange").fetchone()
+    assert isinstance(got, mrange.Range)
+    assert got.isempty
+
+
+def test_load_quoting(conn, testrange):
+    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info.register(conn)
+    cur = conn.cursor()
+    for i in range(1, 254):
+        cur.execute(
+            "select testrange(chr(%(low)s::int), chr(%(up)s::int))",
+            {"low": i, "up": i + 1},
+        )
+        got = cur.fetchone()[0]
+        assert isinstance(got, mrange.Range)
+        assert ord(got.lower) == i
+        assert ord(got.upper) == i + 1