]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added typecasting of composite types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 09:00:16 +0000 (21:00 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 09:12:38 +0000 (21:12 +1200)
psycopg3/types/__init__.py
psycopg3/types/array.py
psycopg3/types/composite.py
psycopg3/types/oids.py
tests/fix_db.py
tests/types/test_composite.py

index 7b2556dfdf15472ae8984a50d31bf2226b9c5844..f578c651ca9f39e26cf85d321d15c4aa6a22b496 100644 (file)
@@ -8,9 +8,10 @@ psycopg3 types package
 from .oids import builtins
 
 # Register default adapters
-from . import composite, numeric, text  # noqa
+from . import array, composite, numeric, text  # noqa
 
 # Register associations with array oids
-from . import array  # noqa
+array.register_all_arrays()
+
 
 __all__ = ["builtins"]
index 5da427192de709e802b3f1e3e6ee4d8e9e124e78..13318cd1760f2ccab4ea731314386029dad1f389 100644 (file)
@@ -292,10 +292,16 @@ def register_array(
         TypeCaster.register(array_oid, t, context=context, format=format)
 
 
-# Register associations between array and base oids
-for t in builtins:
-    if t.array_oid and (
-        (t.oid, Format.TEXT) in TypeCaster.globals
-        or (t.oid, Format.BINARY) in TypeCaster.globals
-    ):
-        register_array(t.array_oid, t.oid, name=t.name)
+def register_all_arrays() -> None:
+    """
+    Associate the array oid of all the types in TypeCaster.globals.
+
+    This function is designed to be called once at import time, after having
+    registered all the base casters.
+    """
+    for t in builtins:
+        if t.array_oid and (
+            (t.oid, Format.TEXT) in TypeCaster.globals
+            or (t.oid, Format.BINARY) in TypeCaster.globals
+        ):
+            register_array(t.array_oid, t.oid, name=t.name)
index f663d4e57a6542e6c842dda5bbe4ae571c958d9b..a40498cba05855e79b7eef4fc97fd522660afaa2 100644 (file)
@@ -4,39 +4,125 @@ Support for composite types adaptation.
 
 import re
 import struct
-from typing import Any, Generator, Optional, Tuple
+from collections import namedtuple
+from typing import Any, Callable, Generator, List, Sequence, Tuple, Union
+from typing import Optional, TYPE_CHECKING
 
 from ..adapt import Format, TypeCaster, Transformer, AdaptContext
-from .oids import builtins
+from .oids import builtins, TypeInfo
+from .array import register_array
+
+if TYPE_CHECKING:
+    from ..connection import Connection
 
 
 TEXT_OID = builtins["text"].oid
 
 
-_re_tokenize = re.compile(
-    br"""(?x)
-      \(? ([,)])                        # an empty token, representing NULL
-    | \(? " ((?: [^"] | "")*) " [,)]    # or a quoted string
-    | \(? ([^",)]+) [,)]                # or an unquoted string
-    """
-)
+class FieldInfo:
+    def __init__(self, name: str, type_oid: int):
+        self.name = name
+        self.type_oid = type_oid
+
+
+class CompositeTypeInfo(TypeInfo):
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        fields: Sequence[Union[FieldInfo, Tuple[str, int]]],
+    ):
+        super().__init__(name, oid, array_oid)
+        self.fields: List[FieldInfo] = []
+        for f in fields:
+            if isinstance(f, FieldInfo):
+                self.fields.append(f)
+            elif isinstance(f, tuple):
+                self.fields.append(FieldInfo(f[0], f[1]))
+            else:
+                raise TypeError(f"bad field info: {f}")
+
+
+def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
+    cur = conn.cursor(binary=True)
+    cur.execute(_type_info_query, (name,))
+    rec = cur.fetchone()
+    if rec is not None:
+        return CompositeTypeInfo(*rec)
+    else:
+        return None
+
+
+def register(
+    info: CompositeTypeInfo,
+    context: AdaptContext = None,
+    factory: Optional[Callable[..., Any]] = None,
+) -> None:
+    if factory is None:
+        factory = namedtuple(  # type: ignore
+            info.name, [f.name for f in info.fields]
+        )
 
-_re_undouble = re.compile(br'(["\\])\1')
+    # generate and register a customized text typecaster
+    caster = type(
+        f"{info.name.title()}Caster",
+        (CompositeCaster,),
+        {
+            "factory": factory,
+            "fields_types": tuple(f.type_oid for f in info.fields),
+        },
+    )
+    TypeCaster.register(info.oid, caster, context=context, format=Format.TEXT)
+
+    # generate and register a customized binary typecaster
+    caster = type(
+        f"{info.name.title()}BinaryCaster",
+        (CompositeBinaryCaster,),
+        {"factory": factory},
+    )
+    TypeCaster.register(
+        info.oid, caster, context=context, format=Format.BINARY
+    )
+
+    if info.array_oid:
+        register_array(info.array_oid, info.oid)
+
+
+_type_info_query = """\
+select
+    name, oid, array_oid,
+    array_agg(row(field_name, field_type)) as fields
+from (
+    select
+        typname as name,
+        t.oid as oid,
+        t.typarray as array_oid,
+        a.attname as field_name,
+        a.atttypid as field_type
+    from pg_type t
+    left join pg_attribute a on a.attrelid = t.typrelid
+    where t.typname = %s
+    and a.attnum > 0
+    order by a.attnum
+) x
+group by name, oid, array_oid
+"""
 
 
 class BaseCompositeCaster(TypeCaster):
     def __init__(self, oid: int, context: AdaptContext = None):
         super().__init__(oid, context)
-        self.tx = Transformer(context)
+        self._tx = Transformer(context)
 
 
 @TypeCaster.text(builtins["record"].oid)
 class RecordCaster(BaseCompositeCaster):
     def cast(self, data: bytes) -> Tuple[Any, ...]:
-        cast = self.tx.get_cast_function(TEXT_OID, format=Format.TEXT)
+        cast = self._tx.get_cast_function(TEXT_OID, format=Format.TEXT)
         return tuple(
-            cast(item) if item is not None else None
-            for item in self._parse_record(data)
+            cast(token) if token is not None else None
+            for token in self._parse_record(data)
         )
 
     def _parse_record(
@@ -45,21 +131,31 @@ class RecordCaster(BaseCompositeCaster):
         if data == b"()":
             return
 
-        for m in _re_tokenize.finditer(data):
+        for m in self._re_tokenize.finditer(data):
             if m.group(1) is not None:
                 yield None
             elif m.group(2) is not None:
-                yield _re_undouble.sub(br"\1", m.group(2))
+                yield self._re_undouble.sub(br"\1", m.group(2))
             else:
                 yield m.group(3)
 
+    _re_tokenize = re.compile(
+        br"""(?x)
+          \(? ([,)])                        # an empty token, representing NULL
+        | \(? " ((?: [^"] | "")*) " [,)]    # or a quoted string
+        | \(? ([^",)]+) [,)]                # or an unquoted string
+        """
+    )
+
+    _re_undouble = re.compile(br'(["\\])\1')
+
 
 _struct_len = struct.Struct("!i")
 _struct_oidlen = struct.Struct("!Ii")
 
 
 @TypeCaster.binary(builtins["record"].oid)
-class BinaryRecordCaster(BaseCompositeCaster):
+class RecordBinaryCaster(BaseCompositeCaster):
     _types_set = False
 
     def cast(self, data: bytes) -> Tuple[Any, ...]:
@@ -68,7 +164,7 @@ class BinaryRecordCaster(BaseCompositeCaster):
             self._types_set = True
 
         return tuple(
-            self.tx.cast_sequence(
+            self._tx.cast_sequence(
                 data[offset : offset + length] if length != -1 else None
                 for _, offset, length in self._walk_record(data)
             )
@@ -88,6 +184,32 @@ class BinaryRecordCaster(BaseCompositeCaster):
             i += (8 + length) if length > 0 else 8
 
     def _config_types(self, data: bytes) -> None:
-        self.tx.set_row_types(
+        self._tx.set_row_types(
             (oid, Format.BINARY) for oid, _, _ in self._walk_record(data)
         )
+
+
+class CompositeCaster(RecordCaster):
+    factory: Callable[..., Any]
+    fields_types: Tuple[int, ...]
+    _types_set = False
+
+    def cast(self, data: bytes) -> Any:
+        if not self._types_set:
+            self._config_types(data)
+            self._types_set = True
+
+        return type(self).factory(
+            *self._tx.cast_sequence(self._parse_record(data))
+        )
+
+    def _config_types(self, data: bytes) -> None:
+        self._tx.set_row_types((oid, Format.TEXT) for oid in self.fields_types)
+
+
+class CompositeBinaryCaster(RecordBinaryCaster):
+    factory: Callable[..., Any]
+
+    def cast(self, data: bytes) -> Any:
+        r = super().cast(data)
+        return type(self).factory(*r)
index 1b211fd652b6ca3bfa38c769bcdeec24d746d59d..39a9d4206c8d16386354b6aa83660d927d58e987 100644 (file)
@@ -8,17 +8,30 @@ to a Postgres server.
 # Copyright (C) 2020 The Psycopg Team
 
 import re
-from typing import Dict, Generator, Optional, NamedTuple, Union
+from typing import Dict, Generator, Optional, Union
 
 INVALID_OID = 0
 
 
-class TypeInfo(NamedTuple):
-    name: str
-    oid: int
-    array_oid: int
-    alt_name: str
-    delimiter: str
+class TypeInfo:
+    def __init__(self, name: str, oid: int, array_oid: int):
+        self.name = name
+        self.oid = oid
+        self.array_oid = array_oid
+
+
+class BuiltinTypeInfo(TypeInfo):
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        alt_name: str,
+        delimiter: str,
+    ):
+        super().__init__(name, oid, array_oid)
+        self.alt_name = alt_name
+        self.delimiter = delimiter
 
 
 class TypesRegistry:
@@ -35,8 +48,10 @@ class TypesRegistry:
         if info.array_oid:
             self._by_oid[info.array_oid] = info
         self._by_name[info.name] = info
-        if info.alt_name not in self._by_name:
-            self._by_name[info.alt_name] = info
+
+        if isinstance(info, BuiltinTypeInfo):
+            if info.alt_name not in self._by_name:
+                self._by_name[info.alt_name] = info
 
     def __iter__(self) -> Generator[TypeInfo, None, None]:
         seen = set()
@@ -143,7 +158,7 @@ for r in [
     # autogenerated: end
     # fmt: on
 ]:
-    builtins.add(TypeInfo(*r))
+    builtins.add(BuiltinTypeInfo(*r))
 
 
 def self_update() -> None:
index 8487fedb88e0cfc128ba30f6355ab3b363cf992e..05743e38228c366cf469b3ed017d10729ebc9304 100644 (file)
@@ -87,7 +87,7 @@ def check_libpq_version(got, want):
         )
 
 
-@pytest.fixture
+@pytest.fixture(scope="session")
 def dsn(request):
     """Return the dsn used to connect to the `--test-dsn` database."""
     dsn = request.config.getoption("--test-dsn")
@@ -113,3 +113,13 @@ def conn(dsn):
     from psycopg3 import Connection
 
     return Connection.connect(dsn)
+
+
+@pytest.fixture(scope="session")
+def svcconn(dsn):
+    """
+    Return a session `Connection` connected to the ``--test-dsn`` database.
+    """
+    from psycopg3 import Connection
+
+    return Connection.connect(dsn)
index 9406e175c2765809aa06a6403e1ce5279e263d2c..cf04e38e759d6cf6ccae393243940e7ebfc03a60 100644 (file)
@@ -1,6 +1,8 @@
 import pytest
 
 from psycopg3.adapt import Format
+from psycopg3.types import builtins
+from psycopg3.types import composite
 
 
 @pytest.mark.parametrize(
@@ -73,3 +75,71 @@ def test_cast_record_binary(conn, want, rec):
     assert res == want
     for o1, o2 in zip(res, want):
         assert type(o1) is type(o2)
+
+
+@pytest.fixture(scope="session")
+def testcomp(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        drop type if exists testcomp cascade;
+        create type testcomp as (foo text, bar int8, baz float8);
+        """
+    )
+
+
+def test_fetch_info(conn, testcomp):
+    info = composite.fetch_info(conn, "testcomp")
+    assert info.name == "testcomp"
+    assert info.oid > 0
+    assert info.oid != info.array_oid > 0
+    assert len(info.fields) == 3
+    for i, (name, t) in enumerate(
+        [("foo", "text"), ("bar", "int8"), ("baz", "float8")]
+    ):
+        assert info.fields[i].name == name
+        assert info.fields[i].type_oid == builtins[t].oid
+
+
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_composite(conn, testcomp, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+    info = composite.fetch_info(conn, "testcomp")
+    composite.register(info)
+
+    res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+    assert res.foo == "hello"
+    assert res.bar == 10
+    assert res.baz == 20.0
+    assert isinstance(res.baz, float)
+
+    res = cur.execute(
+        "select array[row('hello', 10, 30)::testcomp]"
+    ).fetchone()[0]
+    assert len(res) == 1
+    assert res[0].baz == 30.0
+    assert isinstance(res[0].baz, float)
+
+
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_composite_factory(conn, testcomp, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
+    info = composite.fetch_info(conn, "testcomp")
+
+    class MyThing:
+        def __init__(self, *args):
+            self.foo, self.bar, self.baz = args
+
+    composite.register(info, factory=MyThing)
+
+    res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+    assert isinstance(res, MyThing)
+    assert res.baz == 20.0
+    assert isinstance(res.baz, float)
+
+    res = cur.execute(
+        "select array[row('hello', 10, 30)::testcomp]"
+    ).fetchone()[0]
+    assert len(res) == 1
+    assert res[0].baz == 30.0
+    assert isinstance(res[0].baz, float)