From: Daniele Varrazzo Date: Sat, 5 Dec 2020 02:48:40 +0000 (+0000) Subject: Fixed composite info fetch with homonymous types X-Git-Tag: 3.0.dev0~279 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=49ea2dda4520d6477d8753e841fb9579ca734a39;p=thirdparty%2Fpsycopg.git Fixed composite info fetch with homonymous types Use a fully qualified name, optionally expressed as sql.Identifier, to find the type in the right schema. --- diff --git a/psycopg3/psycopg3/types/composite.py b/psycopg3/psycopg3/types/composite.py index 8d309a66f..49580d4ac 100644 --- a/psycopg3/psycopg3/types/composite.py +++ b/psycopg3/psycopg3/types/composite.py @@ -7,9 +7,11 @@ Support for composite types adaptation. import re import struct from collections import namedtuple -from typing import Any, Callable, Iterator, Sequence, Tuple, Type +from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type, Union from typing import Optional, TYPE_CHECKING +from .. import sql +from .. import errors as e from ..oids import builtins, TypeInfo from ..adapt import Format, Dumper, Loader, Transformer from ..proto import AdaptContext @@ -36,29 +38,39 @@ class CompositeTypeInfo(TypeInfo): self.fields = list(fields) @classmethod - def _from_record(cls, rec: Any) -> Optional["CompositeTypeInfo"]: - if rec is None: + def _from_records(cls, recs: List[Any]) -> Optional["CompositeTypeInfo"]: + if not recs: return None + if len(recs) > 1: + raise e.ProgrammingError( + f"found {len(recs)} different types named {recs[0][0]}" + ) - name, oid, array_oid, fnames, ftypes = rec + name, oid, array_oid, fnames, ftypes = recs[0] fields = [FieldInfo(*p) for p in zip(fnames, ftypes)] return CompositeTypeInfo(name, oid, array_oid, fields) -def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]: +def fetch_info( + conn: "Connection", name: Union[str, sql.Identifier] +) -> Optional[CompositeTypeInfo]: + if isinstance(name, sql.Composable): + name = name.as_string(conn) cur = conn.cursor(format=Format.BINARY) cur.execute(_type_info_query, {"name": name}) - rec = cur.fetchone() - return CompositeTypeInfo._from_record(rec) + recs = cur.fetchall() + return CompositeTypeInfo._from_records(recs) async def fetch_info_async( - conn: "AsyncConnection", name: str + conn: "AsyncConnection", name: Union[str, sql.Identifier] ) -> Optional[CompositeTypeInfo]: + if isinstance(name, sql.Composable): + name = name.as_string(conn) cur = await conn.cursor(format=Format.BINARY) await cur.execute(_type_info_query, {"name": name}) - rec = await cur.fetchone() - return CompositeTypeInfo._from_record(rec) + recs = await cur.fetchall() + return CompositeTypeInfo._from_records(recs) def register( @@ -113,14 +125,14 @@ left join ( select a.attrelid, a.attname, a.atttypid from pg_attribute a join pg_type t on t.typrelid = a.attrelid - where t.typname = %(name)s + where t.oid = %(name)s::regtype and a.attnum > 0 and not a.attisdropped order by a.attnum ) x group by attrelid ) a on a.attrelid = t.typrelid -where t.typname = %(name)s +where t.oid = %(name)s::regtype """ diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index ca71a7d39..1ecb9c0c9 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -1,5 +1,6 @@ import pytest +from psycopg3.sql import Identifier from psycopg3.oids import builtins from psycopg3.adapt import Format, Loader from psycopg3.types import composite @@ -96,35 +97,78 @@ def testcomp(svcconn): cur = svcconn.cursor() cur.execute( """ + create schema if not exists testschema; + drop type if exists testcomp cascade; + drop type if exists testschema.testcomp cascade; + create type testcomp as (foo text, bar int8, baz float8); + create type testschema.testcomp as (foo text, bar int8, qux bool); """ ) -def test_fetch_info(conn, testcomp): - info = composite.fetch_info(conn, "testcomp") +@pytest.mark.parametrize( + "name, fields", + [ + ( + "testcomp", + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + "testschema.testcomp", + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ( + Identifier("testcomp"), + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + Identifier("testschema", "testcomp"), + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ], +) +def test_fetch_info(conn, testcomp, name, fields): + info = composite.fetch_info(conn, name) assert info.name == "testcomp" assert info.oid > 0 assert info.oid != info.array_oid > 0 assert len(info.fields) == 3 - for i, (name, t) in enumerate( - [("foo", "text"), ("bar", "int8"), ("baz", "float8")] - ): + for i, (name, t) in enumerate(fields): assert info.fields[i].name == name assert info.fields[i].type_oid == builtins[t].oid +@pytest.mark.parametrize( + "name, fields", + [ + ( + "testcomp", + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + "testschema.testcomp", + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ( + Identifier("testcomp"), + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + Identifier("testschema", "testcomp"), + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ], +) @pytest.mark.asyncio -async def test_fetch_info_async(aconn, testcomp): - info = await composite.fetch_info_async(aconn, "testcomp") +async def test_fetch_info_async(aconn, testcomp, name, fields): + info = await composite.fetch_info_async(aconn, name) assert info.name == "testcomp" assert info.oid > 0 assert info.oid != info.array_oid > 0 assert len(info.fields) == 3 - for i, (name, t) in enumerate( - [("foo", "text"), ("bar", "int8"), ("baz", "float8")] - ): + for i, (name, t) in enumerate(fields): assert info.fields[i].name == name assert info.fields[i].type_oid == builtins[t].oid