From: Daniele Varrazzo Date: Sun, 29 Mar 2020 15:18:19 +0000 (+1300) Subject: Fixed problem with adapters decorators X-Git-Tag: 3.0.dev0~647 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=70f76c71b62b28cd9e12c690dcac78cc1258f186;p=thirdparty%2Fpsycopg.git Fixed problem with adapters decorators Now mypy --strict passes 100% --- diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py index bf7ab98d6..b776da02f 100644 --- a/psycopg3/adaptation.py +++ b/psycopg3/adaptation.py @@ -17,7 +17,6 @@ from typing import ( Tuple, Union, ) -from functools import partial from . import exceptions as exc from .pq import Format, PGresult @@ -54,14 +53,10 @@ class Adapter: @staticmethod def register( cls: type, - adapter: Optional[AdapterType] = None, + adapter: AdapterType, context: Optional[AdaptContext] = None, format: Format = Format.TEXT, ) -> AdapterType: - if adapter is None: - # used as decorator - return partial(Adapter.register, cls, format=format) - if not isinstance(cls, type): raise TypeError( f"adapters should be registered on classes, got {cls} instead" @@ -72,7 +67,7 @@ class Adapter: ): raise TypeError( f"the context should be a connection or cursor," - f" got {type(context).__name__}" + f" got {type(context)}" ) if not ( @@ -91,11 +86,27 @@ class Adapter: @staticmethod def register_binary( cls: type, - adapter: Optional[AdapterType] = None, + adapter: AdapterType, context: Optional[AdaptContext] = None, ) -> AdapterType: return Adapter.register(cls, adapter, context, format=Format.BINARY) + @staticmethod + def text(cls: type) -> Callable[[Any], Any]: + def register_adapter_(adapter: AdapterType) -> AdapterType: + Adapter.register(cls, adapter) + return adapter + + return register_adapter_ + + @staticmethod + def binary(cls: type) -> Callable[[Any], Any]: + def register_binary_adapter_(adapter: AdapterType) -> AdapterType: + Adapter.register_binary(cls, adapter) + return adapter + + return register_binary_adapter_ + class Typecaster: globals: TypecastersMap = {} @@ -110,14 +121,10 @@ class Typecaster: @staticmethod def register( oid: int, - caster: Optional[TypecasterType] = None, + caster: TypecasterType, context: Optional[AdaptContext] = None, format: Format = Format.TEXT, ) -> TypecasterType: - if caster is None: - # used as decorator - return partial(Typecaster.register, oid, format=format) - if not isinstance(oid, int): raise TypeError( f"typecasters should be registered on oid, got {oid} instead" @@ -128,7 +135,7 @@ class Typecaster: ): raise TypeError( f"the context should be a connection or cursor," - f" got {type(context).__name__}" + f" got {type(context)}" ) if not ( @@ -147,11 +154,27 @@ class Typecaster: @staticmethod def register_binary( oid: int, - caster: Optional[TypecasterType] = None, + caster: TypecasterType, context: Optional[AdaptContext] = None, ) -> TypecasterType: return Typecaster.register(oid, caster, context, format=Format.BINARY) + @staticmethod + def text(oid: int) -> Callable[[Any], Any]: + def register_caster_(caster: TypecasterType) -> TypecasterType: + Typecaster.register(oid, caster) + return caster + + return register_caster_ + + @staticmethod + def binary(oid: int) -> Callable[[Any], Any]: + def register_binary_caster_(caster: TypecasterType) -> TypecasterType: + Typecaster.register_binary(oid, caster) + return caster + + return register_binary_caster_ + class Transformer: """ @@ -178,7 +201,7 @@ class Transformer: else: raise TypeError( f"the context should be a connection or cursor," - f" got {type(context).__name__}" + f" got {type(context)}" ) # mapping class, fmt -> adaptation function @@ -264,7 +287,7 @@ class Transformer: return Adapter.globals[key] raise exc.ProgrammingError( - f"cannot adapt type {cls.__name__} to format {Format(fmt).name}" + f"cannot adapt type {cls} to format {Format(fmt).name}" ) def cast_row(self, result: PGresult, n: int) -> Generator[Any, None, None]: @@ -305,7 +328,7 @@ class Transformer: return Typecaster.globals[INVALID_OID, fmt] -@Typecaster.register(INVALID_OID) +@Typecaster.text(INVALID_OID) class UnknownCaster(Typecaster): """ Fallback object to convert unknown types to Python @@ -323,6 +346,6 @@ class UnknownCaster(Typecaster): return self.decode(data)[0] -@Typecaster.register_binary(INVALID_OID) +@Typecaster.binary(INVALID_OID) def cast_unknown(data: bytes) -> bytes: return data diff --git a/psycopg3/types/numeric.py b/psycopg3/types/numeric.py index b633de924..dc1f74777 100644 --- a/psycopg3/types/numeric.py +++ b/psycopg3/types/numeric.py @@ -14,11 +14,11 @@ _encode = codecs.lookup("ascii").encode _decode = codecs.lookup("ascii").decode -@Adapter.register(int) +@Adapter.text(int) def adapt_int(obj: int) -> Tuple[bytes, int]: return _encode(str(obj))[0], type_oid["numeric"] -@Typecaster.register(type_oid["numeric"]) +@Typecaster.text(type_oid["numeric"]) def cast_int(data: bytes) -> int: return int(_decode(data)[0]) diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index 2db4d1d30..20910e71c 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -7,14 +7,17 @@ Adapters of textual types. import codecs from typing import Optional, Union -from ..adaptation import Adapter, Typecaster +from ..adaptation import ( + Adapter, + Typecaster, +) from ..connection import BaseConnection from ..utils.typing import EncodeFunc, DecodeFunc from .oids import type_oid -@Adapter.register(str) -@Adapter.register_binary(str) +@Adapter.text(str) +@Adapter.binary(str) class StringAdapter(Adapter): def __init__(self, cls: type, conn: BaseConnection): super().__init__(cls, conn) @@ -29,8 +32,8 @@ class StringAdapter(Adapter): return self._encode(obj)[0] -@Typecaster.register(type_oid["text"]) -@Typecaster.register_binary(type_oid["text"]) +@Typecaster.text(type_oid["text"]) +@Typecaster.binary(type_oid["text"]) class StringCaster(Typecaster): decode: Optional[DecodeFunc]