]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
style(typeinfo): more modern type annotations
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 3 May 2024 19:45:24 +0000 (21:45 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 15 May 2024 15:56:39 +0000 (17:56 +0200)
psycopg/psycopg/_typeinfo.py

index fc170492a5683340fe19fc8223f6040bdaa56fda..a95376bec22dfef6299de736dd571b069f108c6e 100644 (file)
@@ -7,8 +7,9 @@ information to the adapters if needed.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Dict, Iterator, Optional, overload
-from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
+from __future__ import annotations
+
+from typing import Any, Iterator, overload, Sequence, TYPE_CHECKING
 
 from . import sql
 from . import errors as e
@@ -23,7 +24,7 @@ if TYPE_CHECKING:
     from ._connection_base import BaseConnection
 
 T = TypeVar("T", bound="TypeInfo")
-RegistryKey: TypeAlias = Union[str, int, Tuple[type, int]]
+RegistryKey: TypeAlias = "str | int | tuple[type, int]"
 
 
 class TypeInfo:
@@ -57,18 +58,18 @@ class TypeInfo:
     @overload
     @classmethod
     def fetch(
-        cls: Type[T], conn: "Connection[Any]", name: Union[str, sql.Identifier]
-    ) -> Optional[T]: ...
+        cls: type[T], conn: Connection[Any], name: str | sql.Identifier
+    ) -> T | None: ...
 
     @overload
     @classmethod
     async def fetch(
-        cls: Type[T], conn: "AsyncConnection[Any]", name: Union[str, sql.Identifier]
-    ) -> Optional[T]: ...
+        cls: type[T], conn: AsyncConnection[Any], name: str | sql.Identifier
+    ) -> T | None: ...
 
     @classmethod
     def fetch(
-        cls: Type[T], conn: "BaseConnection[Any]", name: Union[str, sql.Identifier]
+        cls: type[T], conn: BaseConnection[Any], name: str | sql.Identifier
     ) -> Any:
         """Query a system catalog to read information about a type."""
         from .connection import Connection
@@ -87,7 +88,7 @@ class TypeInfo:
             )
 
     @classmethod
-    def _fetch(cls: Type[T], conn: "Connection[Any]", name: str) -> Optional[T]:
+    def _fetch(cls: type[T], conn: Connection[Any], name: str) -> T | None:
         # This might result in a nested transaction. What we want is to leave
         # the function with the connection in the state we found (either idle
         # or intrans)
@@ -106,8 +107,8 @@ class TypeInfo:
 
     @classmethod
     async def _fetch_async(
-        cls: Type[T], conn: "AsyncConnection[Any]", name: str
-    ) -> Optional[T]:
+        cls: type[T], conn: AsyncConnection[Any], name: str
+    ) -> T | None:
         try:
             from psycopg import AsyncCursor
 
@@ -124,8 +125,8 @@ class TypeInfo:
 
     @classmethod
     def _from_records(
-        cls: Type[T], name: str, recs: Sequence[Dict[str, Any]]
-    ) -> Optional[T]:
+        cls: type[T], name: str, recs: Sequence[dict[str, Any]]
+    ) -> T | None:
         if len(recs) == 1:
             return cls(**recs[0])
         elif not recs:
@@ -133,7 +134,7 @@ class TypeInfo:
         else:
             raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
 
-    def register(self, context: Optional[AdaptContext] = None) -> None:
+    def register(self, context: AdaptContext | None = None) -> None:
         """
         Register the type information, globally or in the specified `!context`.
         """
@@ -152,7 +153,7 @@ class TypeInfo:
             register_array(self, context)
 
     @classmethod
-    def _get_info_query(cls, conn: "BaseConnection[Any]") -> Query:
+    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
         return sql.SQL(
             """\
 SELECT
@@ -165,7 +166,7 @@ ORDER BY t.oid
         ).format(regtype=cls._to_regtype(conn))
 
     @classmethod
-    def _has_to_regtype_function(cls, conn: "BaseConnection[Any]") -> bool:
+    def _has_to_regtype_function(cls, conn: BaseConnection[Any]) -> bool:
         # to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
         info = conn.info
         if info.vendor == "PostgreSQL":
@@ -176,7 +177,7 @@ ORDER BY t.oid
             return False
 
     @classmethod
-    def _to_regtype(cls, conn: "BaseConnection[Any]") -> sql.SQL:
+    def _to_regtype(cls, conn: BaseConnection[Any]) -> sql.SQL:
         # `to_regtype()` returns the type oid or NULL, unlike the :: operator,
         # which returns the type or raises an exception, which requires
         # a transaction rollback and leaves traces in the server logs.
@@ -186,7 +187,7 @@ ORDER BY t.oid
         else:
             return sql.SQL("%(name)s::regtype")
 
-    def _added(self, registry: "TypesRegistry") -> None:
+    def _added(self, registry: TypesRegistry) -> None:
         """Method called by the `!registry` when the object is added there."""
         pass
 
@@ -198,8 +199,8 @@ class TypesRegistry:
 
     __module__ = "psycopg.types"
 
-    def __init__(self, template: Optional["TypesRegistry"] = None):
-        self._registry: Dict[RegistryKey, TypeInfo]
+    def __init__(self, template: TypesRegistry | None = None):
+        self._registry: dict[RegistryKey, TypeInfo]
 
         # Make a shallow copy: it will become a proper copy if the registry
         # is edited.
@@ -236,10 +237,10 @@ class TypesRegistry:
                 yield t
 
     @overload
-    def __getitem__(self, key: Union[str, int]) -> TypeInfo: ...
+    def __getitem__(self, key: str | int) -> TypeInfo: ...
 
     @overload
-    def __getitem__(self, key: Tuple[Type[T], int]) -> T: ...
+    def __getitem__(self, key: tuple[type[T], int]) -> T: ...
 
     def __getitem__(self, key: RegistryKey) -> TypeInfo:
         """
@@ -260,12 +261,12 @@ class TypesRegistry:
             raise KeyError(f"couldn't find the type {key!r} in the types registry")
 
     @overload
-    def get(self, key: Union[str, int]) -> Optional[TypeInfo]: ...
+    def get(self, key: str | int) -> TypeInfo | None: ...
 
     @overload
-    def get(self, key: Tuple[Type[T], int]) -> Optional[T]: ...
+    def get(self, key: tuple[type[T], int]) -> T | None: ...
 
-    def get(self, key: RegistryKey) -> Optional[TypeInfo]:
+    def get(self, key: RegistryKey) -> TypeInfo | None:
         """
         Return info about a type, specified by name or oid
 
@@ -294,7 +295,7 @@ class TypesRegistry:
         else:
             return t.oid
 
-    def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
+    def get_by_subtype(self, cls: type[T], subtype: int | str) -> T | None:
         """
         Return info about a `TypeInfo` subclass by its element name or oid.