from typing_extensions import TypeAlias
from . import errors as e
-from .abc import AdaptContext
+from .abc import AdaptContext, Query
from .rows import dict_row
+
if TYPE_CHECKING:
- from .connection import Connection
+ from .connection import BaseConnection, Connection
from .connection_async import AsyncConnection
- from .sql import Identifier
+ from .sql import Identifier, SQL
T = TypeVar("T", bound="TypeInfo")
RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
@classmethod
def fetch(
cls: Type[T],
- conn: "Union[Connection[Any], AsyncConnection[Any]]",
+ conn: Union["Connection[Any]", "AsyncConnection[Any]"],
name: Union[str, "Identifier"],
) -> Any:
"""Query a system catalog to read information about a type."""
register_array(self, context)
@classmethod
- def _get_info_query(
- cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
- ) -> str:
- return """\
+ def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
+ from .sql import SQL
+
+ return SQL(
+ """\
SELECT
typname AS name, oid, typarray AS array_oid,
oid::regtype::text AS regtype, typdelim AS delimiter
FROM pg_type t
-WHERE t.oid = %(name)s::regtype
+WHERE t.oid = {regtype}
ORDER BY t.oid
"""
+ ).format(regtype=cls._to_regtype(conn))
+
+ @classmethod
+ def _has_to_regtype_function(cls, conn: "BaseConnection[Any]") -> bool:
+ # introduced in PostgreSQL 9.4 and CockroachDB 22.2
+ info = conn.info
+ return info.vendor == "PostgreSQL" or (
+ info.vendor == "CockroachDB" and info.server_version >= 220200
+ )
+
+ @classmethod
+ def _to_regtype(cls, conn: "BaseConnection[Any]") -> "SQL":
+ from .sql import SQL
+
+ if cls._has_to_regtype_function(conn):
+ return SQL("to_regtype(%(name)s)")
+ else:
+ return SQL("%(name)s::regtype")
def _added(self, registry: "TypesRegistry") -> None:
"""Method called by the `!registry` when the object is added there."""
self.subtype_oid = subtype_oid
@classmethod
- def _get_info_query(
- cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
- ) -> str:
+ def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
+ # CockroachDB does not support range so no need to use _to_regtype
return """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngtypid
-WHERE t.oid = %(name)s::regtype
+WHERE t.oid = to_regtype(%(name)s)
"""
def _added(self, registry: "TypesRegistry") -> None:
self.subtype_oid = subtype_oid
@classmethod
- def _get_info_query(
- cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
- ) -> str:
+ def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
if conn.info.server_version < 140000:
raise e.NotSupportedError(
"multirange types are only available from PostgreSQL 14"
)
+ # CockroachDB does not support multirange so no need to use _to_regtype
return """\
SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
t.oid::regtype::text AS regtype,
r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
FROM pg_type t
JOIN pg_range r ON t.oid = r.rngmultitypid
-WHERE t.oid = %(name)s::regtype
+WHERE t.oid = to_regtype(%(name)s)
"""
def _added(self, registry: "TypesRegistry") -> None:
self.python_type: Optional[type] = None
@classmethod
- def _get_info_query(
- cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
- ) -> str:
+ def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
+ # CockroachDB does not support composite so no need to use _to_regtype
return """\
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
SELECT a.attrelid, a.attname, a.atttypid
FROM pg_attribute a
JOIN pg_type t ON t.typrelid = a.attrelid
- WHERE t.oid = %(name)s::regtype
+ WHERE t.oid = to_regtype(%(name)s)
AND a.attnum > 0
AND NOT a.attisdropped
ORDER BY a.attnum
) x
GROUP BY attrelid
) a ON a.attrelid = t.typrelid
-WHERE t.oid = %(name)s::regtype
+WHERE t.oid = to_regtype(%(name)s)
"""
self.enum: Optional[Type[Enum]] = None
@classmethod
- def _get_info_query(
- cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
- ) -> str:
- return """\
+ def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
+ from .sql import SQL
+
+ return SQL(
+ """\
SELECT name, oid, array_oid, array_agg(label) AS labels
FROM (
SELECT
FROM pg_type t
LEFT JOIN pg_enum e
ON e.enumtypid = t.oid
- WHERE t.oid = %(name)s::regtype
+ WHERE t.oid = {regtype}
ORDER BY e.enumsortorder
) x
GROUP BY name, oid, array_oid
"""
+ ).format(regtype=cls._to_regtype(conn))
class TypesRegistry:
from psycopg import sql
from psycopg.pq import TransactionStatus
from psycopg.types import TypeInfo
+from psycopg.types.composite import CompositeInfo
+from psycopg.types.enum import EnumInfo
+from psycopg.types.multirange import MultirangeInfo
+from psycopg.types.range import RangeInfo
@pytest.mark.parametrize("name", ["text", sql.Identifier("text")])
assert info.array_oid == psycopg.adapters.types["text"].array_oid
-@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
-@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
-def test_fetch_not_found(conn, name, status):
+_name = pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
+_status = pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
+_info_cls = pytest.mark.parametrize(
+ "info_cls",
+ [
+ pytest.param(TypeInfo),
+ pytest.param(RangeInfo, marks=pytest.mark.crdb_skip("range")),
+ pytest.param(
+ MultirangeInfo,
+ marks=(pytest.mark.crdb_skip("range"), pytest.mark.pg(">= 14")),
+ ),
+ pytest.param(CompositeInfo, marks=pytest.mark.crdb_skip("composite")),
+ pytest.param(EnumInfo),
+ ],
+)
+
+
+@_name
+@_status
+@_info_cls
+def test_fetch_not_found(conn, name, status, info_cls, monkeypatch):
+
+ if TypeInfo._has_to_regtype_function(conn):
+ exit_orig = psycopg.Transaction.__exit__
+
+ def exit(self, exc_type, exc_val, exc_tb):
+ assert exc_val is None
+ return exit_orig(self, exc_type, exc_val, exc_tb)
+
+ monkeypatch.setattr(psycopg.Transaction, "__exit__", exit)
status = getattr(TransactionStatus, status)
if status == TransactionStatus.INTRANS:
conn.execute("select 1")
assert conn.info.transaction_status == status
- info = TypeInfo.fetch(conn, name)
+ info = info_cls.fetch(conn, name)
assert conn.info.transaction_status == status
assert info is None
@pytest.mark.asyncio
-@pytest.mark.parametrize("name", ["nosuch", sql.Identifier("nosuch")])
-@pytest.mark.parametrize("status", ["IDLE", "INTRANS"])
-async def test_fetch_not_found_async(aconn, name, status):
+@_name
+@_status
+@_info_cls
+async def test_fetch_not_found_async(aconn, name, status, info_cls, monkeypatch):
+
+ if TypeInfo._has_to_regtype_function(aconn):
+ exit_orig = psycopg.AsyncTransaction.__aexit__
+
+ async def aexit(self, exc_type, exc_val, exc_tb):
+ assert exc_val is None
+ return await exit_orig(self, exc_type, exc_val, exc_tb)
+
+ monkeypatch.setattr(psycopg.AsyncTransaction, "__aexit__", aexit)
status = getattr(TransactionStatus, status)
if status == TransactionStatus.INTRANS:
await aconn.execute("select 1")
assert aconn.info.transaction_status == status
- info = await TypeInfo.fetch(aconn, name)
+ info = await info_cls.fetch(aconn, name)
assert aconn.info.transaction_status == status
assert info is None