]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Simplify the types registry using a single map
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 29 Sep 2021 20:01:31 +0000 (20:01 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 30 Sep 2021 00:57:57 +0000 (00:57 +0000)
psycopg/psycopg/_typeinfo.py

index 93380e59c211d5f056250d6d9ff5ae69480cc600..62efefa8cd7a7d1dedd0f00376ee2b33fd63908e 100644 (file)
@@ -8,7 +8,7 @@ information to the adapters if needed.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 from typing import Any, Dict, Iterator, Optional, overload
-from typing import Sequence, Type, TypeVar, Union, TYPE_CHECKING
+from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
 
 from . import errors as e
 from .abc import AdaptContext
@@ -186,7 +186,7 @@ WHERE t.oid = %(name)s::regtype
     def _added(self, registry: "TypesRegistry") -> None:
         """Method called by the *registry* when the object is added there."""
         # Map ranges subtypes to info
-        registry._by_range_subtype[self.subtype_oid] = self
+        registry._registry[RangeInfo, self.subtype_oid] = self
 
 
 class CompositeInfo(TypeInfo):
@@ -234,6 +234,9 @@ WHERE t.oid = %(name)s::regtype
 """
 
 
+RegistryKey = Union[str, int, Tuple[type, int]]
+
+
 class TypesRegistry:
     """
     Container for the information about types in a database.
@@ -242,44 +245,38 @@ class TypesRegistry:
     __module__ = "psycopg.types"
 
     def __init__(self, template: Optional["TypesRegistry"] = None):
-        self._by_oid: Dict[int, TypeInfo]
-        self._by_name: Dict[str, TypeInfo]
-        self._by_range_subtype: Dict[int, TypeInfo]
+        self._registry: Dict[RegistryKey, TypeInfo]
 
         # Make a shallow copy: it will become a proper copy if the registry
         # is edited.
         if template:
-            self._by_oid = template._by_oid
-            self._by_name = template._by_name
-            self._by_range_subtype = template._by_range_subtype
+            self._registry = template._registry
             self._own_state = False
             template._own_state = False
         else:
             self.clear()
 
     def clear(self) -> None:
-        self._by_oid = {}
-        self._by_name = {}
-        self._by_range_subtype = {}
+        self._registry = {}
         self._own_state = True
 
     def add(self, info: TypeInfo) -> None:
         self._ensure_own_state()
         if info.oid:
-            self._by_oid[info.oid] = info
+            self._registry[info.oid] = info
         if info.array_oid:
-            self._by_oid[info.array_oid] = info
-        self._by_name[info.name] = info
+            self._registry[info.array_oid] = info
+        self._registry[info.name] = info
 
-        if info.alt_name and info.alt_name not in self._by_name:
-            self._by_name[info.alt_name] = info
+        if info.alt_name and info.alt_name not in self._registry:
+            self._registry[info.alt_name] = info
 
         # Allow info to customise further their relation with the registry
         info._added(self)
 
     def __iter__(self) -> Iterator[TypeInfo]:
         seen = set()
-        for t in self._by_oid.values():
+        for t in self._registry.values():
             if t.oid not in seen:
                 seen.add(t.oid)
                 yield t
@@ -292,17 +289,15 @@ class TypesRegistry:
 
         Raise KeyError if not found.
         """
+        if isinstance(key, str):
+            if key.endswith("[]"):
+                key = key[:-2]
+        elif not isinstance(key, int):
+            raise TypeError(
+                f"the key must be an oid or a name, got {type(key)}"
+            )
         try:
-            if isinstance(key, str):
-                if key.endswith("[]"):
-                    key = key[:-2]
-                return self._by_name[key]
-            elif isinstance(key, int):
-                return self._by_oid[key]
-            else:
-                raise TypeError(
-                    f"the key must be an oid or a name, got {type(key)}"
-                )
+            return self._registry[key]
         except KeyError:
             raise KeyError(
                 f"couldn't find the type {key!r} in the types registry"
@@ -347,12 +342,10 @@ class TypesRegistry:
             info = self[key]
         except KeyError:
             return None
-        return self._by_range_subtype.get(info.oid)
+        return self._registry.get((RangeInfo, info.oid))
 
     def _ensure_own_state(self) -> None:
         # Time to write! so, copy.
         if not self._own_state:
-            self._by_oid = self._by_oid.copy()
-            self._by_name = self._by_name.copy()
-            self._by_range_subtype = self._by_range_subtype.copy()
+            self._registry = self._registry.copy()
             self._own_state = True