# Copyright (C) 2020-2021 The Psycopg Team
from typing import Any, Dict, Iterator, Optional, overload
-from typing import Sequence, Type, TypeVar, Union, TYPE_CHECKING
+from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
from . import errors as e
from .abc import AdaptContext
def _added(self, registry: "TypesRegistry") -> None:
"""Method called by the *registry* when the object is added there."""
# Map ranges subtypes to info
- registry._by_range_subtype[self.subtype_oid] = self
+ registry._registry[RangeInfo, self.subtype_oid] = self
class CompositeInfo(TypeInfo):
"""
+RegistryKey = Union[str, int, Tuple[type, int]]
+
+
class TypesRegistry:
"""
Container for the information about types in a database.
__module__ = "psycopg.types"
def __init__(self, template: Optional["TypesRegistry"] = None):
- self._by_oid: Dict[int, TypeInfo]
- self._by_name: Dict[str, TypeInfo]
- self._by_range_subtype: Dict[int, TypeInfo]
+ self._registry: Dict[RegistryKey, TypeInfo]
# Make a shallow copy: it will become a proper copy if the registry
# is edited.
if template:
- self._by_oid = template._by_oid
- self._by_name = template._by_name
- self._by_range_subtype = template._by_range_subtype
+ self._registry = template._registry
self._own_state = False
template._own_state = False
else:
self.clear()
def clear(self) -> None:
- self._by_oid = {}
- self._by_name = {}
- self._by_range_subtype = {}
+ self._registry = {}
self._own_state = True
def add(self, info: TypeInfo) -> None:
self._ensure_own_state()
if info.oid:
- self._by_oid[info.oid] = info
+ self._registry[info.oid] = info
if info.array_oid:
- self._by_oid[info.array_oid] = info
- self._by_name[info.name] = info
+ self._registry[info.array_oid] = info
+ self._registry[info.name] = info
- if info.alt_name and info.alt_name not in self._by_name:
- self._by_name[info.alt_name] = info
+ if info.alt_name and info.alt_name not in self._registry:
+ self._registry[info.alt_name] = info
# Allow info to customise further their relation with the registry
info._added(self)
def __iter__(self) -> Iterator[TypeInfo]:
seen = set()
- for t in self._by_oid.values():
+ for t in self._registry.values():
if t.oid not in seen:
seen.add(t.oid)
yield t
Raise KeyError if not found.
"""
+ if isinstance(key, str):
+ if key.endswith("[]"):
+ key = key[:-2]
+ elif not isinstance(key, int):
+ raise TypeError(
+ f"the key must be an oid or a name, got {type(key)}"
+ )
try:
- if isinstance(key, str):
- if key.endswith("[]"):
- key = key[:-2]
- return self._by_name[key]
- elif isinstance(key, int):
- return self._by_oid[key]
- else:
- raise TypeError(
- f"the key must be an oid or a name, got {type(key)}"
- )
+ return self._registry[key]
except KeyError:
raise KeyError(
f"couldn't find the type {key!r} in the types registry"
info = self[key]
except KeyError:
return None
- return self._by_range_subtype.get(info.oid)
+ return self._registry.get((RangeInfo, info.oid))
def _ensure_own_state(self) -> None:
# Time to write! so, copy.
if not self._own_state:
- self._by_oid = self._by_oid.copy()
- self._by_name = self._by_name.copy()
- self._by_range_subtype = self._by_range_subtype.copy()
+ self._registry = self._registry.copy()
self._own_state = True