]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed composite info fetch with homonymous types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Dec 2020 02:48:40 +0000 (02:48 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 5 Dec 2020 02:48:40 +0000 (02:48 +0000)
Use a fully qualified name, optionally expressed as sql.Identifier, to
find the type in the right schema.

psycopg3/psycopg3/types/composite.py
tests/types/test_composite.py

index 8d309a66f1b75676326199f91e8b6cda0b0976b4..49580d4acbc9a266817c2fa32fe4ca93010a7339 100644 (file)
@@ -7,9 +7,11 @@ Support for composite types adaptation.
 import re
 import struct
 from collections import namedtuple
-from typing import Any, Callable, Iterator, Sequence, Tuple, Type
+from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type, Union
 from typing import Optional, TYPE_CHECKING
 
+from .. import sql
+from .. import errors as e
 from ..oids import builtins, TypeInfo
 from ..adapt import Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
@@ -36,29 +38,39 @@ class CompositeTypeInfo(TypeInfo):
         self.fields = list(fields)
 
     @classmethod
-    def _from_record(cls, rec: Any) -> Optional["CompositeTypeInfo"]:
-        if rec is None:
+    def _from_records(cls, recs: List[Any]) -> Optional["CompositeTypeInfo"]:
+        if not recs:
             return None
+        if len(recs) > 1:
+            raise e.ProgrammingError(
+                f"found {len(recs)} different types named {recs[0][0]}"
+            )
 
-        name, oid, array_oid, fnames, ftypes = rec
+        name, oid, array_oid, fnames, ftypes = recs[0]
         fields = [FieldInfo(*p) for p in zip(fnames, ftypes)]
         return CompositeTypeInfo(name, oid, array_oid, fields)
 
 
-def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
+def fetch_info(
+    conn: "Connection", name: Union[str, sql.Identifier]
+) -> Optional[CompositeTypeInfo]:
+    if isinstance(name, sql.Composable):
+        name = name.as_string(conn)
     cur = conn.cursor(format=Format.BINARY)
     cur.execute(_type_info_query, {"name": name})
-    rec = cur.fetchone()
-    return CompositeTypeInfo._from_record(rec)
+    recs = cur.fetchall()
+    return CompositeTypeInfo._from_records(recs)
 
 
 async def fetch_info_async(
-    conn: "AsyncConnection", name: str
+    conn: "AsyncConnection", name: Union[str, sql.Identifier]
 ) -> Optional[CompositeTypeInfo]:
+    if isinstance(name, sql.Composable):
+        name = name.as_string(conn)
     cur = await conn.cursor(format=Format.BINARY)
     await cur.execute(_type_info_query, {"name": name})
-    rec = await cur.fetchone()
-    return CompositeTypeInfo._from_record(rec)
+    recs = await cur.fetchall()
+    return CompositeTypeInfo._from_records(recs)
 
 
 def register(
@@ -113,14 +125,14 @@ left join (
         select a.attrelid, a.attname, a.atttypid
         from pg_attribute a
         join pg_type t on t.typrelid = a.attrelid
-        where t.typname = %(name)s
+        where t.oid = %(name)s::regtype
         and a.attnum > 0
         and not a.attisdropped
         order by a.attnum
     ) x
     group by attrelid
 ) a on a.attrelid = t.typrelid
-where t.typname = %(name)s
+where t.oid = %(name)s::regtype
 """
 
 
index ca71a7d39bca872a6810270f05b2b2ba35690ade..1ecb9c0c99b38176b6f96a926755109e0d28debb 100644 (file)
@@ -1,5 +1,6 @@
 import pytest
 
+from psycopg3.sql import Identifier
 from psycopg3.oids import builtins
 from psycopg3.adapt import Format, Loader
 from psycopg3.types import composite
@@ -96,35 +97,78 @@ def testcomp(svcconn):
     cur = svcconn.cursor()
     cur.execute(
         """
+        create schema if not exists testschema;
+
         drop type if exists testcomp cascade;
+        drop type if exists testschema.testcomp cascade;
+
         create type testcomp as (foo text, bar int8, baz float8);
+        create type testschema.testcomp as (foo text, bar int8, qux bool);
         """
     )
 
 
-def test_fetch_info(conn, testcomp):
-    info = composite.fetch_info(conn, "testcomp")
+@pytest.mark.parametrize(
+    "name, fields",
+    [
+        (
+            "testcomp",
+            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+        ),
+        (
+            "testschema.testcomp",
+            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+        ),
+        (
+            Identifier("testcomp"),
+            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+        ),
+        (
+            Identifier("testschema", "testcomp"),
+            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+        ),
+    ],
+)
+def test_fetch_info(conn, testcomp, name, fields):
+    info = composite.fetch_info(conn, name)
     assert info.name == "testcomp"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
     assert len(info.fields) == 3
-    for i, (name, t) in enumerate(
-        [("foo", "text"), ("bar", "int8"), ("baz", "float8")]
-    ):
+    for i, (name, t) in enumerate(fields):
         assert info.fields[i].name == name
         assert info.fields[i].type_oid == builtins[t].oid
 
 
+@pytest.mark.parametrize(
+    "name, fields",
+    [
+        (
+            "testcomp",
+            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+        ),
+        (
+            "testschema.testcomp",
+            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+        ),
+        (
+            Identifier("testcomp"),
+            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+        ),
+        (
+            Identifier("testschema", "testcomp"),
+            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+        ),
+    ],
+)
 @pytest.mark.asyncio
-async def test_fetch_info_async(aconn, testcomp):
-    info = await composite.fetch_info_async(aconn, "testcomp")
+async def test_fetch_info_async(aconn, testcomp, name, fields):
+    info = await composite.fetch_info_async(aconn, name)
     assert info.name == "testcomp"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
     assert len(info.fields) == 3
-    for i, (name, t) in enumerate(
-        [("foo", "text"), ("bar", "int8"), ("baz", "float8")]
-    ):
+    for i, (name, t) in enumerate(fields):
         assert info.fields[i].name == name
         assert info.fields[i].type_oid == builtins[t].oid