# Copyright (C) 2020 The Psycopg Team
-from typing import Any, cast, Dict, Generic, Optional, TypeVar, Type
+import re
+from typing import Any, Dict, Generic, List, Optional, TypeVar, Type, Union
+from typing import cast, TYPE_CHECKING
from decimal import Decimal
from datetime import date, datetime
-from ..oids import builtins
+from .. import sql
+from .. import errors as e
+from ..oids import builtins, TypeInfo
from ..adapt import Format, Dumper, Loader
+from ..proto import AdaptContext
+
+from . import array
from .composite import SequenceDumper, BaseCompositeLoader
+if TYPE_CHECKING:
+ from ..connection import Connection, AsyncConnection
+
T = TypeVar("T")
b",",
)
+ _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
+
class RangeLoader(BaseCompositeLoader, Generic[T]):
"""Generic loader for a range.
class TimestampTZRangeLoader(RangeLoader[datetime]):
subtype_oid = builtins["timestamptz"].oid
cls = DateTimeTZRange
+
+
+class RangeInfo(TypeInfo):
+ """Manage information about a range type.
+
+ The class allows to:
+
+ - read information about a range type using `fetch()` and `fetch_async()`
+ - configure a composite type adaptation using `register()`
+ """
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ subtype_oid: int,
+ ):
+ super().__init__(name, oid, array_oid)
+ self.subtype_oid = subtype_oid
+
+ @classmethod
+ def fetch(
+ cls, conn: "Connection", name: Union[str, sql.Identifier]
+ ) -> Optional["RangeInfo"]:
+ if isinstance(name, sql.Composable):
+ name = name.as_string(conn)
+ cur = conn.cursor(format=Format.BINARY)
+ cur.execute(cls._info_query, {"name": name})
+ recs = cur.fetchall()
+ return cls._from_records(recs)
+
+ @classmethod
+ async def fetch_async(
+ cls, conn: "AsyncConnection", name: Union[str, sql.Identifier]
+ ) -> Optional["RangeInfo"]:
+ if isinstance(name, sql.Composable):
+ name = name.as_string(conn)
+ cur = await conn.cursor(format=Format.BINARY)
+ await cur.execute(cls._info_query, {"name": name})
+ recs = await cur.fetchall()
+ return cls._from_records(recs)
+
+ def register(
+ self,
+ context: AdaptContext = None,
+ range_class: Optional[Type[Range[Any]]] = None,
+ ) -> None:
+ if not range_class:
+ range_class = type(self.name.title(), (Range,), {})
+
+ # generate and register a customized text dumper
+ dumper: Type[Dumper] = type(
+ f"{self.name.title()}Dumper", (RangeDumper,), {"oid": self.oid}
+ )
+ dumper.register(range_class, context=context, format=Format.TEXT)
+
+ # generate and register a customized text loader
+ loader: Type[Loader] = type(
+ f"{self.name.title()}Loader",
+ (RangeLoader,),
+ {"cls": range_class, "subtype_oid": self.subtype_oid},
+ )
+ loader.register(self.oid, context=context, format=Format.TEXT)
+
+ if self.array_oid:
+ array.register(
+ self.array_oid, self.oid, context=context, name=self.name
+ )
+
+ @classmethod
+ def _from_records(cls, recs: List[Any]) -> Optional["RangeInfo"]:
+ if not recs:
+ return None
+ if len(recs) > 1:
+ raise e.ProgrammingError(
+ f"found {len(recs)} different ranges named {recs[0][0]}"
+ )
+
+ name, oid, array_oid, subtype = recs[0]
+ return cls(name, oid, array_oid, subtype)
+
+ _info_query = """\
+select t.typname as name, t.oid as oid, t.typarray as array_oid,
+ r.rngsubtype as subtype_oid
+from pg_type t
+join pg_range r on t.oid = r.rngtypid
+where t.oid = %(name)s::regtype
+"""
import pytest
-from psycopg3.types import range
+from psycopg3.sql import Identifier
+from psycopg3.oids import builtins
+from psycopg3.types import range as mrange
type2cls = {
- "int4range": range.Int4Range,
- "int8range": range.Int8Range,
- "numrange": range.DecimalRange,
- "daterange": range.DateRange,
- "tsrange": range.DateTimeRange,
- "tstzrange": range.DateTimeTZRange,
+ "int4range": mrange.Int4Range,
+ "int8range": mrange.Int8Range,
+ "numrange": mrange.DecimalRange,
+ "daterange": mrange.DateRange,
+ "tsrange": mrange.DateTimeRange,
+ "tstzrange": mrange.DateTimeTZRange,
}
type2sub = {
"int4range": "int4",
"pgtype",
"int4range int8range numrange daterange tsrange tstzrange".split(),
)
-def test_dump_builtin_range_empty(conn, pgtype):
+def test_dump_builtin_empty(conn, pgtype):
r = type2cls[pgtype](empty=True)
cur = conn.cursor()
cur.execute(f"select 'empty'::{pgtype} = %s", (r,))
"pgtype",
"int4range int8range numrange daterange tsrange tstzrange".split(),
)
-def test_load_builtin_range_empty(conn, pgtype):
+def test_load_builtin_empty(conn, pgtype):
r = type2cls[pgtype](empty=True)
cur = conn.cursor()
(got,) = cur.execute(f"select 'empty'::{pgtype}").fetchone()
bounds = "[)" if r.lower_inc else "()"
r = type(r)(r.lower, r.upper + 1, bounds)
assert cur.fetchone()[0] == r
+
+
+@pytest.fixture(scope="session")
+def testrange(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ create schema if not exists testschema;
+
+ drop type if exists testrange cascade;
+ drop type if exists testschema.testrange cascade;
+
+ create type testrange as range (subtype = text, collation = "C");
+ create type testschema.testrange as range (subtype = float8);
+ """
+ )
+
+
+fetch_cases = [
+ ("testrange", "text"),
+ ("testschema.testrange", "float8"),
+ (Identifier("testrange"), "text"),
+ (Identifier("testschema", "testrange"), "float8"),
+]
+
+
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+def test_fetch_info(conn, testrange, name, subtype):
+ info = mrange.RangeInfo.fetch(conn, name)
+ assert info.name == "testrange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == builtins[subtype].oid
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("name, subtype", fetch_cases)
+async def test_fetch_info_async(aconn, testrange, name, subtype):
+ info = await mrange.RangeInfo.fetch_async(aconn, name)
+ assert info.name == "testrange"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert info.subtype_oid == builtins[subtype].oid
+
+
+def test_dump_custom_empty(conn, testrange):
+ class StrRange(mrange.Range):
+ pass
+
+ info = mrange.RangeInfo.fetch(conn, "testrange")
+ info.register(conn, range_class=StrRange)
+
+ r = StrRange(empty=True)
+ cur = conn.cursor()
+ cur.execute("select 'empty'::testrange = %s", (r,))
+ assert cur.fetchone()[0] is True
+
+
+def test_dump_quoting(conn, testrange):
+ class StrRange(mrange.Range):
+ pass
+
+ info = mrange.RangeInfo.fetch(conn, "testrange")
+ info.register(conn, range_class=StrRange)
+ cur = conn.cursor()
+ for i in range(1, 254):
+ cur.execute(
+ "select ascii(lower(%(r)s)) = %(low)s and ascii(upper(%(r)s)) = %(up)s",
+ {"r": StrRange(chr(i), chr(i + 1)), "low": i, "up": i + 1},
+ )
+ assert cur.fetchone()[0] is True
+
+
+def test_load_custom_empty(conn, testrange):
+ info = mrange.RangeInfo.fetch(conn, "testrange")
+ info.register(conn)
+
+ cur = conn.cursor()
+ (got,) = cur.execute("select 'empty'::testrange").fetchone()
+ assert isinstance(got, mrange.Range)
+ assert got.isempty
+
+
+def test_load_quoting(conn, testrange):
+ info = mrange.RangeInfo.fetch(conn, "testrange")
+ info.register(conn)
+ cur = conn.cursor()
+ for i in range(1, 254):
+ cur.execute(
+ "select testrange(chr(%(low)s::int), chr(%(up)s::int))",
+ {"low": i, "up": i + 1},
+ )
+ got = cur.fetchone()[0]
+ assert isinstance(got, mrange.Range)
+ assert ord(got.lower) == i
+ assert ord(got.upper) == i + 1