]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
register function and global adapters map moved as class members
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 14:54:21 +0000 (03:54 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 14:54:21 +0000 (03:54 +1300)
psycopg3/adaptation.py

index d717381a09be4d2b5865bddf2ea34530f3fc3000..a8be71ff485799809acd7f362b559d69a3a32ea0 100644 (file)
@@ -14,73 +14,94 @@ from .connection import BaseConnection
 
 INVALID_OID = 0
 
-global_adapters = {}
-global_casters = {}
 
+class Adapter:
+    globals = {}
 
-def register_adapter(cls, adapter, context=None, format=Format.TEXT):
-    if not isinstance(cls, type):
-        raise TypeError(
-            f"adapters should be registered on classes, got {cls} instead"
-        )
+    def __init__(self, cls, conn):
+        self.cls = cls
+        self.conn = conn
 
-    if context is not None and not isinstance(
-        context, (BaseConnection, BaseCursor)
-    ):
-        raise TypeError(
-            f"the context should be a connection or cursor,"
-            f" got {type(context).__name__}"
-        )
+    def adapt(self, obj):
+        raise NotImplementedError()
 
-    if not (
-        callable(adapter)
-        or (isinstance(adapter, type) and issubclass(adapter, Adapter))
-    ):
-        raise TypeError(
-            f"adapters should be callable or Adapter subclasses, got {adapter}"
-        )
+    @staticmethod
+    def register(cls, adapter, context=None, format=Format.TEXT):
+        if not isinstance(cls, type):
+            raise TypeError(
+                f"adapters should be registered on classes, got {cls} instead"
+            )
 
-    where = context.adapters if context is not None else global_adapters
-    where[cls, format] = adapter
+        if context is not None and not isinstance(
+            context, (BaseConnection, BaseCursor)
+        ):
+            raise TypeError(
+                f"the context should be a connection or cursor,"
+                f" got {type(context).__name__}"
+            )
 
+        if not (
+            callable(adapter)
+            or (isinstance(adapter, type) and issubclass(adapter, Adapter))
+        ):
+            raise TypeError(
+                f"adapters should be callable or Adapter subclasses,"
+                f" got {adapter} instead"
+            )
 
-def register_binary_adapter(cls, adapter, context=None):
-    register_adapter(cls, adapter, context, format=Format.BINARY)
+        where = context.adapters if context is not None else Adapter.globals
+        where[cls, format] = adapter
 
+    @staticmethod
+    def register_binary(cls, adapter, context=None):
+        Adapter.register(cls, adapter, context, format=Format.BINARY)
 
-def register_caster(oid, caster, context=None, format=Format.TEXT):
-    if not isinstance(oid, int):
-        raise TypeError(
-            f"typecasters should be registered on oid, got {oid} instead"
-        )
 
-    if context is not None and not isinstance(
-        context, (BaseConnection, BaseCursor)
-    ):
-        raise TypeError(
-            f"the context should be a connection or cursor,"
-            f" got {type(context).__name__}"
-        )
+class Typecaster:
+    globals = {}
 
-    if not (
-        callable(caster)
-        or (isinstance(caster, type) and issubclass(caster, Typecaster))
-    ):
-        raise TypeError(
-            f"adapters should be callable or Typecaster subclasses, got {caster}"
-        )
+    def __init__(self, oid, conn):
+        self.oid = oid
+        self.conn = conn
 
-    where = context.adapters if context is not None else global_casters
-    where[oid, format] = caster
+    def cast(self, data):
+        raise NotImplementedError()
 
+    @staticmethod
+    def register(oid, caster, context=None, format=Format.TEXT):
+        if not isinstance(oid, int):
+            raise TypeError(
+                f"typecasters should be registered on oid, got {oid} instead"
+            )
 
-def register_binary_caster(oid, caster, context=None):
-    register_caster(oid, caster, context, format=Format.BINARY)
+        if context is not None and not isinstance(
+            context, (BaseConnection, BaseCursor)
+        ):
+            raise TypeError(
+                f"the context should be a connection or cursor,"
+                f" got {type(context).__name__}"
+            )
+
+        if not (
+            callable(caster)
+            or (isinstance(caster, type) and issubclass(caster, Typecaster))
+        ):
+            raise TypeError(
+                f"adapters should be callable or Typecaster subclasses,"
+                f" got {caster} instead"
+            )
+
+        where = context.adapters if context is not None else Typecaster.globals
+        where[oid, format] = caster
+
+    @staticmethod
+    def register_binary(oid, caster, context=None):
+        Typecaster.register(oid, caster, context, format=Format.BINARY)
 
 
 def adapter(oid):
     def adapter_(obj):
-        register_adapter(oid, obj)
+        Adapter.register(oid, obj)
         return obj
 
     return adapter_
@@ -88,7 +109,7 @@ def adapter(oid):
 
 def binary_adapter(oid):
     def binary_adapter_(obj):
-        register_binary_adapter(oid, obj)
+        Adapter.register_binary(oid, obj)
         return obj
 
     return binary_adapter_
@@ -96,7 +117,7 @@ def binary_adapter(oid):
 
 def caster(oid):
     def caster_(obj):
-        register_caster(oid, obj)
+        Typecaster.register(oid, obj)
         return obj
 
     return caster_
@@ -104,7 +125,7 @@ def caster(oid):
 
 def binary_caster(oid):
     def binary_caster_(obj):
-        register_binary_caster(oid, obj)
+        Typecaster.register_binary(oid, obj)
         return obj
 
     return binary_caster_
@@ -212,8 +233,8 @@ class Transformer:
         if conn is not None and key in conn.adapters:
             return conn.adapters[key]
 
-        if key in global_adapters:
-            return global_adapters[key]
+        if key in Adapter.globals:
+            return Adapter.globals[key]
 
         raise exc.ProgrammingError(
             f"cannot adapt type {cls.__name__} to format {Format(fmt).name}"
@@ -251,28 +272,10 @@ class Transformer:
         if conn is not None and key in conn.casters:
             return conn.casters[key]
 
-        if key in global_casters:
-            return global_casters[key]
-
-        return global_casters[INVALID_OID, fmt]
-
+        if key in Typecaster.globals:
+            return Typecaster.globals[key]
 
-class Adapter:
-    def __init__(self, cls, conn):
-        self.cls = cls
-        self.conn = conn
-
-    def adapt(self, obj):
-        raise NotImplementedError()
-
-
-class Typecaster:
-    def __init__(self, oid, conn):
-        self.oid = oid
-        self.conn = conn
-
-    def cast(self, data):
-        raise NotImplementedError()
+        return Typecaster.globals[INVALID_OID, fmt]
 
 
 @caster(INVALID_OID)