]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Replace 'TypesRegistry.get_range()' with more generic 'get_by_subtipe()'
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 30 Sep 2021 00:55:23 +0000 (00:55 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 30 Sep 2021 00:57:57 +0000 (00:57 +0000)
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/types/range.py

index 62efefa8cd7a7d1dedd0f00376ee2b33fd63908e..e0902bdc506fc70ff9bf10113a655f810d53412a 100644 (file)
@@ -281,7 +281,15 @@ class TypesRegistry:
                 seen.add(t.oid)
                 yield t
 
+    @overload
     def __getitem__(self, key: Union[str, int]) -> TypeInfo:
+        ...
+
+    @overload
+    def __getitem__(self, key: Tuple[Type[T], int]) -> T:
+        ...
+
+    def __getitem__(self, key: RegistryKey) -> TypeInfo:
         """
         Return info about a type, specified by name or oid
 
@@ -292,7 +300,7 @@ class TypesRegistry:
         if isinstance(key, str):
             if key.endswith("[]"):
                 key = key[:-2]
-        elif not isinstance(key, int):
+        elif not isinstance(key, (int, tuple)):
             raise TypeError(
                 f"the key must be an oid or a name, got {type(key)}"
             )
@@ -303,7 +311,15 @@ class TypesRegistry:
                 f"couldn't find the type {key!r} in the types registry"
             )
 
+    @overload
     def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
+        ...
+
+    @overload
+    def get(self, key: Tuple[Type[T], int]) -> Optional[T]:
+        ...
+
+    def get(self, key: RegistryKey) -> Optional[TypeInfo]:
         """
         Return info about a type, specified by name or oid
 
@@ -332,17 +348,19 @@ class TypesRegistry:
         else:
             return t.oid
 
-    def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]:
+    def get_by_subtype(
+        self, cls: Type[T], subtype: Union[int, str]
+    ) -> Optional[T]:
         """
-        Return info about a range by its element name or oid
+        Return info about a TypeInfo subclass by its element name or oid
 
         Return None if the element or its range are not found.
         """
         try:
-            info = self[key]
+            info = self[subtype]
         except KeyError:
             return None
-        return self._registry.get((RangeInfo, info.oid))
+        return self.get((cls, info.oid))
 
     def _ensure_own_state(self) -> None:
         # Time to write! so, copy.
index c8f1ba5161d1200a887df39a720e0e490e92d1dc..06b81a6fc6800d2c3d2858b18ca218cc9b7d6d72 100644 (file)
@@ -307,10 +307,8 @@ class BaseRangeDumper(RecursiveDumper):
     def _get_range_oid(self, sub_oid: int) -> int:
         """
         Return the oid of the range from the oid of its elements.
-
-        Raise InterfaceError if not found.
         """
-        info = self._tx.adapters.types.get_range(sub_oid)
+        info = self._tx.adapters.types.get_by_subtype(RangeInfo, sub_oid)
         return info.oid if info else INVALID_OID