]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added adapter/caster decorators
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 12:06:09 +0000 (01:06 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 12:06:09 +0000 (01:06 +1300)
psycopg3/adaptation.py

index 293a47d00725fe33e785cd0adb8bb977abcacd91..3404b3c83f06978073a43551dfcbfd4d3b328061 100644 (file)
@@ -24,15 +24,27 @@ global_casters = {}
 
 
 def register_adapter(cls, adapter, context=None, format=Format.TEXT):
+    if not isinstance(cls, type):
+        raise TypeError(
+            f"adapters should be registered on classes, got {cls} instead"
+        )
 
     if context is not None and not isinstance(
         context, (BaseConnection, BaseCursor)
     ):
         raise TypeError(
-            f"the context should be a connection or cursor;"
+            f"the context should be a connection or cursor,"
             f" got {type(context).__name__}"
         )
 
+    if not (
+        callable(adapter)
+        or (isinstance(adapter, type) and issubclass(adapter, Adapter))
+    ):
+        raise TypeError(
+            f"adapters should be callable or Adapter subclasses, got {adapter}"
+        )
+
     where = context.adapters if context is not None else global_adapters
     where[cls, format] = adapter
 
@@ -42,14 +54,27 @@ def register_binary_adapter(cls, adapter, context=None):
 
 
 def register_caster(oid, caster, context=None, format=Format.TEXT):
+    if not isinstance(oid, int):
+        raise TypeError(
+            f"typecasters should be registered on oid, got {oid} instead"
+        )
+
     if context is not None and not isinstance(
         context, (BaseConnection, BaseCursor)
     ):
         raise TypeError(
-            f"the context should be a connection or cursor;"
+            f"the context should be a connection or cursor,"
             f" got {type(context).__name__}"
         )
 
+    if not (
+        callable(caster)
+        or (isinstance(caster, type) and issubclass(caster, Typecaster))
+    ):
+        raise TypeError(
+            f"adapters should be callable or Typecaster subclasses, got {caster}"
+        )
+
     where = context.adapters if context is not None else global_casters
     where[oid, format] = caster
 
@@ -58,6 +83,38 @@ def register_binary_caster(oid, caster, context=None):
     register_caster(oid, caster, context, format=Format.BINARY)
 
 
+def adapter(oid):
+    def adapter_(obj):
+        register_adapter(oid, obj)
+        return obj
+
+    return adapter_
+
+
+def binary_adapter(oid):
+    def binary_adapter_(obj):
+        register_binary_adapter(oid, obj)
+        return obj
+
+    return binary_adapter_
+
+
+def caster(oid):
+    def caster_(obj):
+        register_caster(oid, obj)
+        return obj
+
+    return caster_
+
+
+def binary_caster(oid):
+    def binary_caster_(obj):
+        register_binary_caster(oid, obj)
+        return obj
+
+    return binary_caster_
+
+
 class Transformer:
     """
     An object that can adapt efficiently between Python and PostgreSQL.
@@ -223,6 +280,8 @@ class Typecaster:
         raise NotImplementedError()
 
 
+@adapter(str)
+@binary_adapter(str)
 class StringAdapter(Adapter):
     def __init__(self, cls, conn):
         super().__init__(cls, conn)
@@ -232,6 +291,8 @@ class StringAdapter(Adapter):
         return self.encode(obj)[0]
 
 
+@caster(TEXT_OID)
+@binary_caster(TEXT_OID)
 class StringCaster(Typecaster):
     def __init__(self, oid, conn):
         super().__init__(oid, conn)
@@ -251,25 +312,17 @@ class StringCaster(Typecaster):
             return data
 
 
-register_adapter(str, StringAdapter)
-register_binary_adapter(str, StringAdapter)
-
-register_caster(TEXT_OID, StringCaster)
-register_binary_caster(TEXT_OID, StringCaster)
-
-
+@adapter(int)
 def adapt_int(obj):
     return ascii_encode(str(obj))[0], NUMERIC_OID
 
 
+@caster(NUMERIC_OID)
 def cast_int(data):
     return int(ascii_decode(data)[0])
 
 
-register_adapter(int, adapt_int)
-register_caster(NUMERIC_OID, cast_int)
-
-
+@caster(INVALID_OID)
 class UnknownCaster(Typecaster):
     """
     Fallback object to convert unknown types to Python
@@ -286,9 +339,6 @@ class UnknownCaster(Typecaster):
         return self.decode(data)[0]
 
 
-def binary_cast_unknown(data):
+@binary_caster(INVALID_OID)
+def cast_unknown(data):
     return data
-
-
-register_caster(INVALID_OID, UnknownCaster)
-register_binary_caster(INVALID_OID, binary_cast_unknown)