]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Adapter/caster register functions can be used as decorators
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 15:02:51 +0000 (04:02 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 15:02:51 +0000 (04:02 +1300)
psycopg3/adaptation.py
psycopg3/types/numeric.py
psycopg3/types/text.py

index a8be71ff485799809acd7f362b559d69a3a32ea0..b13e013ec9c1d16cddcf5ae210b875d626de0666 100644 (file)
@@ -5,6 +5,7 @@ Entry point into the adaptation system.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
+from functools import partial
 
 from . import exceptions as exc
 from .pq import Format
@@ -26,7 +27,11 @@ class Adapter:
         raise NotImplementedError()
 
     @staticmethod
-    def register(cls, adapter, context=None, format=Format.TEXT):
+    def register(cls, adapter=None, context=None, format=Format.TEXT):
+        if adapter is None:
+            # used as decorator
+            return partial(Adapter.register, cls, format=format)
+
         if not isinstance(cls, type):
             raise TypeError(
                 f"adapters should be registered on classes, got {cls} instead"
@@ -51,10 +56,11 @@ class Adapter:
 
         where = context.adapters if context is not None else Adapter.globals
         where[cls, format] = adapter
+        return adapter
 
     @staticmethod
-    def register_binary(cls, adapter, context=None):
-        Adapter.register(cls, adapter, context, format=Format.BINARY)
+    def register_binary(cls, adapter=None, context=None):
+        return Adapter.register(cls, adapter, context, format=Format.BINARY)
 
 
 class Typecaster:
@@ -68,7 +74,11 @@ class Typecaster:
         raise NotImplementedError()
 
     @staticmethod
-    def register(oid, caster, context=None, format=Format.TEXT):
+    def register(oid, caster=None, context=None, format=Format.TEXT):
+        if caster is None:
+            # used as decorator
+            return partial(Typecaster.register, oid, format=format)
+
         if not isinstance(oid, int):
             raise TypeError(
                 f"typecasters should be registered on oid, got {oid} instead"
@@ -93,42 +103,11 @@ class Typecaster:
 
         where = context.adapters if context is not None else Typecaster.globals
         where[oid, format] = caster
+        return caster
 
     @staticmethod
-    def register_binary(oid, caster, context=None):
-        Typecaster.register(oid, caster, context, format=Format.BINARY)
-
-
-def adapter(oid):
-    def adapter_(obj):
-        Adapter.register(oid, obj)
-        return obj
-
-    return adapter_
-
-
-def binary_adapter(oid):
-    def binary_adapter_(obj):
-        Adapter.register_binary(oid, obj)
-        return obj
-
-    return binary_adapter_
-
-
-def caster(oid):
-    def caster_(obj):
-        Typecaster.register(oid, obj)
-        return obj
-
-    return caster_
-
-
-def binary_caster(oid):
-    def binary_caster_(obj):
-        Typecaster.register_binary(oid, obj)
-        return obj
-
-    return binary_caster_
+    def register_binary(oid, caster=None, context=None):
+        return Typecaster.register(oid, caster, context, format=Format.BINARY)
 
 
 class Transformer:
@@ -278,7 +257,7 @@ class Transformer:
         return Typecaster.globals[INVALID_OID, fmt]
 
 
-@caster(INVALID_OID)
+@Typecaster.register(INVALID_OID)
 class UnknownCaster(Typecaster):
     """
     Fallback object to convert unknown types to Python
@@ -295,6 +274,6 @@ class UnknownCaster(Typecaster):
         return self.decode(data)[0]
 
 
-@binary_caster(INVALID_OID)
+@Typecaster.register_binary(INVALID_OID)
 def cast_unknown(data):
     return data
index ecf179ff40d1a4343bc690a602cd89d40a849d43..60d8a25aeced360ec177771dc21931fa13906f1b 100644 (file)
@@ -6,15 +6,15 @@ Adapters of numeric types.
 
 import codecs
 
-from ..adaptation import adapter, caster
+from ..adaptation import Adapter, Typecaster
 from .oids import type_oid
 
 
-@adapter(int)
+@Adapter.register(int)
 def adapt_int(obj, encode=codecs.lookup("ascii").encode):
     return encode(str(obj))[0], type_oid["numeric"]
 
 
-@caster(type_oid["numeric"])
+@Typecaster.register(type_oid["numeric"])
 def cast_int(data, decode=codecs.lookup("ascii").decode):
     return int(decode(data)[0])
index 37515fb2de451111ff2df901b144be93f50ad615..3d1405deba22ee948ff461176c8e56c46172103c 100644 (file)
@@ -6,13 +6,13 @@ Adapters of textual types.
 
 import codecs
 
-from ..adaptation import Adapter, adapter, binary_adapter
-from ..adaptation import Typecaster, caster, binary_caster
+from ..adaptation import Adapter
+from ..adaptation import Typecaster
 from .oids import type_oid
 
 
-@adapter(str)
-@binary_adapter(str)
+@Adapter.register(str)
+@Adapter.register_binary(str)
 class StringAdapter(Adapter):
     def __init__(self, cls, conn):
         super().__init__(cls, conn)
@@ -24,8 +24,8 @@ class StringAdapter(Adapter):
         return self.encode(obj)[0]
 
 
-@caster(type_oid["text"])
-@binary_caster(type_oid["text"])
+@Typecaster.register(type_oid["text"])
+@Typecaster.register_binary(type_oid["text"])
 class StringCaster(Typecaster):
     def __init__(self, oid, conn):
         super().__init__(oid, conn)