import re
import struct
from collections import namedtuple
-from typing import Any, Callable, Iterator, Sequence, Tuple, Type
+from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type, Union
from typing import Optional, TYPE_CHECKING
+from .. import sql
+from .. import errors as e
from ..oids import builtins, TypeInfo
from ..adapt import Format, Dumper, Loader, Transformer
from ..proto import AdaptContext
self.fields = list(fields)
@classmethod
- def _from_record(cls, rec: Any) -> Optional["CompositeTypeInfo"]:
- if rec is None:
+ def _from_records(cls, recs: List[Any]) -> Optional["CompositeTypeInfo"]:
+ if not recs:
return None
+ if len(recs) > 1:
+ raise e.ProgrammingError(
+ f"found {len(recs)} different types named {recs[0][0]}"
+ )
- name, oid, array_oid, fnames, ftypes = rec
+ name, oid, array_oid, fnames, ftypes = recs[0]
fields = [FieldInfo(*p) for p in zip(fnames, ftypes)]
return CompositeTypeInfo(name, oid, array_oid, fields)
-def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
+def fetch_info(
+ conn: "Connection", name: Union[str, sql.Identifier]
+) -> Optional[CompositeTypeInfo]:
+ if isinstance(name, sql.Composable):
+ name = name.as_string(conn)
cur = conn.cursor(format=Format.BINARY)
cur.execute(_type_info_query, {"name": name})
- rec = cur.fetchone()
- return CompositeTypeInfo._from_record(rec)
+ recs = cur.fetchall()
+ return CompositeTypeInfo._from_records(recs)
async def fetch_info_async(
- conn: "AsyncConnection", name: str
+ conn: "AsyncConnection", name: Union[str, sql.Identifier]
) -> Optional[CompositeTypeInfo]:
+ if isinstance(name, sql.Composable):
+ name = name.as_string(conn)
cur = await conn.cursor(format=Format.BINARY)
await cur.execute(_type_info_query, {"name": name})
- rec = await cur.fetchone()
- return CompositeTypeInfo._from_record(rec)
+ recs = await cur.fetchall()
+ return CompositeTypeInfo._from_records(recs)
def register(
select a.attrelid, a.attname, a.atttypid
from pg_attribute a
join pg_type t on t.typrelid = a.attrelid
- where t.typname = %(name)s
+ where t.oid = %(name)s::regtype
and a.attnum > 0
and not a.attisdropped
order by a.attnum
) x
group by attrelid
) a on a.attrelid = t.typrelid
-where t.typname = %(name)s
+where t.oid = %(name)s::regtype
"""
import pytest
+from psycopg3.sql import Identifier
from psycopg3.oids import builtins
from psycopg3.adapt import Format, Loader
from psycopg3.types import composite
cur = svcconn.cursor()
cur.execute(
"""
+ create schema if not exists testschema;
+
drop type if exists testcomp cascade;
+ drop type if exists testschema.testcomp cascade;
+
create type testcomp as (foo text, bar int8, baz float8);
+ create type testschema.testcomp as (foo text, bar int8, qux bool);
"""
)
-def test_fetch_info(conn, testcomp):
- info = composite.fetch_info(conn, "testcomp")
+@pytest.mark.parametrize(
+ "name, fields",
+ [
+ (
+ "testcomp",
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ "testschema.testcomp",
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ (
+ Identifier("testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ Identifier("testschema", "testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ ],
+)
+def test_fetch_info(conn, testcomp, name, fields):
+ info = composite.fetch_info(conn, name)
assert info.name == "testcomp"
assert info.oid > 0
assert info.oid != info.array_oid > 0
assert len(info.fields) == 3
- for i, (name, t) in enumerate(
- [("foo", "text"), ("bar", "int8"), ("baz", "float8")]
- ):
+ for i, (name, t) in enumerate(fields):
assert info.fields[i].name == name
assert info.fields[i].type_oid == builtins[t].oid
+@pytest.mark.parametrize(
+ "name, fields",
+ [
+ (
+ "testcomp",
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ "testschema.testcomp",
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ (
+ Identifier("testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ Identifier("testschema", "testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ ],
+)
@pytest.mark.asyncio
-async def test_fetch_info_async(aconn, testcomp):
- info = await composite.fetch_info_async(aconn, "testcomp")
+async def test_fetch_info_async(aconn, testcomp, name, fields):
+ info = await composite.fetch_info_async(aconn, name)
assert info.name == "testcomp"
assert info.oid > 0
assert info.oid != info.array_oid > 0
assert len(info.fields) == 3
- for i, (name, t) in enumerate(
- [("foo", "text"), ("bar", "int8"), ("baz", "float8")]
- ):
+ for i, (name, t) in enumerate(fields):
assert info.fields[i].name == name
assert info.fields[i].type_oid == builtins[t].oid