From: Daniele Varrazzo Date: Sat, 28 Mar 2020 12:06:09 +0000 (+1300) Subject: Added adapter/caster decorators X-Git-Tag: 3.0.dev0~660 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=518fcc9655790e50fe58c5e30e50349cc77543cf;p=thirdparty%2Fpsycopg.git Added adapter/caster decorators --- diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py index 293a47d00..3404b3c83 100644 --- a/psycopg3/adaptation.py +++ b/psycopg3/adaptation.py @@ -24,15 +24,27 @@ global_casters = {} def register_adapter(cls, adapter, context=None, format=Format.TEXT): + if not isinstance(cls, type): + raise TypeError( + f"adapters should be registered on classes, got {cls} instead" + ) if context is not None and not isinstance( context, (BaseConnection, BaseCursor) ): raise TypeError( - f"the context should be a connection or cursor;" + f"the context should be a connection or cursor," f" got {type(context).__name__}" ) + if not ( + callable(adapter) + or (isinstance(adapter, type) and issubclass(adapter, Adapter)) + ): + raise TypeError( + f"adapters should be callable or Adapter subclasses, got {adapter}" + ) + where = context.adapters if context is not None else global_adapters where[cls, format] = adapter @@ -42,14 +54,27 @@ def register_binary_adapter(cls, adapter, context=None): def register_caster(oid, caster, context=None, format=Format.TEXT): + if not isinstance(oid, int): + raise TypeError( + f"typecasters should be registered on oid, got {oid} instead" + ) + if context is not None and not isinstance( context, (BaseConnection, BaseCursor) ): raise TypeError( - f"the context should be a connection or cursor;" + f"the context should be a connection or cursor," f" got {type(context).__name__}" ) + if not ( + callable(caster) + or (isinstance(caster, type) and issubclass(caster, Typecaster)) + ): + raise TypeError( + f"adapters should be callable or Typecaster subclasses, got {caster}" + ) + where = context.adapters if context is not None else global_casters where[oid, format] = caster @@ -58,6 +83,38 @@ def register_binary_caster(oid, caster, context=None): register_caster(oid, caster, context, format=Format.BINARY) +def adapter(oid): + def adapter_(obj): + register_adapter(oid, obj) + return obj + + return adapter_ + + +def binary_adapter(oid): + def binary_adapter_(obj): + register_binary_adapter(oid, obj) + return obj + + return binary_adapter_ + + +def caster(oid): + def caster_(obj): + register_caster(oid, obj) + return obj + + return caster_ + + +def binary_caster(oid): + def binary_caster_(obj): + register_binary_caster(oid, obj) + return obj + + return binary_caster_ + + class Transformer: """ An object that can adapt efficiently between Python and PostgreSQL. @@ -223,6 +280,8 @@ class Typecaster: raise NotImplementedError() +@adapter(str) +@binary_adapter(str) class StringAdapter(Adapter): def __init__(self, cls, conn): super().__init__(cls, conn) @@ -232,6 +291,8 @@ class StringAdapter(Adapter): return self.encode(obj)[0] +@caster(TEXT_OID) +@binary_caster(TEXT_OID) class StringCaster(Typecaster): def __init__(self, oid, conn): super().__init__(oid, conn) @@ -251,25 +312,17 @@ class StringCaster(Typecaster): return data -register_adapter(str, StringAdapter) -register_binary_adapter(str, StringAdapter) - -register_caster(TEXT_OID, StringCaster) -register_binary_caster(TEXT_OID, StringCaster) - - +@adapter(int) def adapt_int(obj): return ascii_encode(str(obj))[0], NUMERIC_OID +@caster(NUMERIC_OID) def cast_int(data): return int(ascii_decode(data)[0]) -register_adapter(int, adapt_int) -register_caster(NUMERIC_OID, cast_int) - - +@caster(INVALID_OID) class UnknownCaster(Typecaster): """ Fallback object to convert unknown types to Python @@ -286,9 +339,6 @@ class UnknownCaster(Typecaster): return self.decode(data)[0] -def binary_cast_unknown(data): +@binary_caster(INVALID_OID) +def cast_unknown(data): return data - - -register_caster(INVALID_OID, UnknownCaster) -register_binary_caster(INVALID_OID, binary_cast_unknown)