]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed problem with adapters decorators
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Mar 2020 15:18:19 +0000 (04:18 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 29 Mar 2020 15:37:55 +0000 (04:37 +1300)
Now mypy --strict passes 100%

psycopg3/adaptation.py
psycopg3/types/numeric.py
psycopg3/types/text.py

index bf7ab98d6dadee78c87de792d0397af90d84ea98..b776da02f18824125f394a03a2a730d53a0933b9 100644 (file)
@@ -17,7 +17,6 @@ from typing import (
     Tuple,
     Union,
 )
-from functools import partial
 
 from . import exceptions as exc
 from .pq import Format, PGresult
@@ -54,14 +53,10 @@ class Adapter:
     @staticmethod
     def register(
         cls: type,
-        adapter: Optional[AdapterType] = None,
+        adapter: AdapterType,
         context: Optional[AdaptContext] = None,
         format: Format = Format.TEXT,
     ) -> AdapterType:
-        if adapter is None:
-            # used as decorator
-            return partial(Adapter.register, cls, format=format)
-
         if not isinstance(cls, type):
             raise TypeError(
                 f"adapters should be registered on classes, got {cls} instead"
@@ -72,7 +67,7 @@ class Adapter:
         ):
             raise TypeError(
                 f"the context should be a connection or cursor,"
-                f" got {type(context).__name__}"
+                f" got {type(context)}"
             )
 
         if not (
@@ -91,11 +86,27 @@ class Adapter:
     @staticmethod
     def register_binary(
         cls: type,
-        adapter: Optional[AdapterType] = None,
+        adapter: AdapterType,
         context: Optional[AdaptContext] = None,
     ) -> AdapterType:
         return Adapter.register(cls, adapter, context, format=Format.BINARY)
 
+    @staticmethod
+    def text(cls: type) -> Callable[[Any], Any]:
+        def register_adapter_(adapter: AdapterType) -> AdapterType:
+            Adapter.register(cls, adapter)
+            return adapter
+
+        return register_adapter_
+
+    @staticmethod
+    def binary(cls: type) -> Callable[[Any], Any]:
+        def register_binary_adapter_(adapter: AdapterType) -> AdapterType:
+            Adapter.register_binary(cls, adapter)
+            return adapter
+
+        return register_binary_adapter_
+
 
 class Typecaster:
     globals: TypecastersMap = {}
@@ -110,14 +121,10 @@ class Typecaster:
     @staticmethod
     def register(
         oid: int,
-        caster: Optional[TypecasterType] = None,
+        caster: TypecasterType,
         context: Optional[AdaptContext] = None,
         format: Format = Format.TEXT,
     ) -> TypecasterType:
-        if caster is None:
-            # used as decorator
-            return partial(Typecaster.register, oid, format=format)
-
         if not isinstance(oid, int):
             raise TypeError(
                 f"typecasters should be registered on oid, got {oid} instead"
@@ -128,7 +135,7 @@ class Typecaster:
         ):
             raise TypeError(
                 f"the context should be a connection or cursor,"
-                f" got {type(context).__name__}"
+                f" got {type(context)}"
             )
 
         if not (
@@ -147,11 +154,27 @@ class Typecaster:
     @staticmethod
     def register_binary(
         oid: int,
-        caster: Optional[TypecasterType] = None,
+        caster: TypecasterType,
         context: Optional[AdaptContext] = None,
     ) -> TypecasterType:
         return Typecaster.register(oid, caster, context, format=Format.BINARY)
 
+    @staticmethod
+    def text(oid: int) -> Callable[[Any], Any]:
+        def register_caster_(caster: TypecasterType) -> TypecasterType:
+            Typecaster.register(oid, caster)
+            return caster
+
+        return register_caster_
+
+    @staticmethod
+    def binary(oid: int) -> Callable[[Any], Any]:
+        def register_binary_caster_(caster: TypecasterType) -> TypecasterType:
+            Typecaster.register_binary(oid, caster)
+            return caster
+
+        return register_binary_caster_
+
 
 class Transformer:
     """
@@ -178,7 +201,7 @@ class Transformer:
         else:
             raise TypeError(
                 f"the context should be a connection or cursor,"
-                f" got {type(context).__name__}"
+                f" got {type(context)}"
             )
 
         # mapping class, fmt -> adaptation function
@@ -264,7 +287,7 @@ class Transformer:
             return Adapter.globals[key]
 
         raise exc.ProgrammingError(
-            f"cannot adapt type {cls.__name__} to format {Format(fmt).name}"
+            f"cannot adapt type {cls} to format {Format(fmt).name}"
         )
 
     def cast_row(self, result: PGresult, n: int) -> Generator[Any, None, None]:
@@ -305,7 +328,7 @@ class Transformer:
         return Typecaster.globals[INVALID_OID, fmt]
 
 
-@Typecaster.register(INVALID_OID)
+@Typecaster.text(INVALID_OID)
 class UnknownCaster(Typecaster):
     """
     Fallback object to convert unknown types to Python
@@ -323,6 +346,6 @@ class UnknownCaster(Typecaster):
         return self.decode(data)[0]
 
 
-@Typecaster.register_binary(INVALID_OID)
+@Typecaster.binary(INVALID_OID)
 def cast_unknown(data: bytes) -> bytes:
     return data
index b633de9241cee1591dd73f9f7751eda17ba35a17..dc1f74777be92511a3e848688de523beb3111d65 100644 (file)
@@ -14,11 +14,11 @@ _encode = codecs.lookup("ascii").encode
 _decode = codecs.lookup("ascii").decode
 
 
-@Adapter.register(int)
+@Adapter.text(int)
 def adapt_int(obj: int) -> Tuple[bytes, int]:
     return _encode(str(obj))[0], type_oid["numeric"]
 
 
-@Typecaster.register(type_oid["numeric"])
+@Typecaster.text(type_oid["numeric"])
 def cast_int(data: bytes) -> int:
     return int(_decode(data)[0])
index 2db4d1d300ecb771c3d36060ebcde573924bb75a..20910e71c2e10d2b753704f5d64c2892ffd42605 100644 (file)
@@ -7,14 +7,17 @@ Adapters of textual types.
 import codecs
 from typing import Optional, Union
 
-from ..adaptation import Adapter, Typecaster
+from ..adaptation import (
+    Adapter,
+    Typecaster,
+)
 from ..connection import BaseConnection
 from ..utils.typing import EncodeFunc, DecodeFunc
 from .oids import type_oid
 
 
-@Adapter.register(str)
-@Adapter.register_binary(str)
+@Adapter.text(str)
+@Adapter.binary(str)
 class StringAdapter(Adapter):
     def __init__(self, cls: type, conn: BaseConnection):
         super().__init__(cls, conn)
@@ -29,8 +32,8 @@ class StringAdapter(Adapter):
         return self._encode(obj)[0]
 
 
-@Typecaster.register(type_oid["text"])
-@Typecaster.register_binary(type_oid["text"])
+@Typecaster.text(type_oid["text"])
+@Typecaster.binary(type_oid["text"])
 class StringCaster(Typecaster):
 
     decode: Optional[DecodeFunc]