]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added register_adapter/caster functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 11:50:19 +0000 (00:50 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 11:50:19 +0000 (00:50 +1300)
psycopg3/adaptation.py
psycopg3/cursor.py

index 5bb1b4096084d0866bc9fa5e134ba6eb977e0c50..293a47d00725fe33e785cd0adb8bb977abcacd91 100644 (file)
@@ -8,6 +8,8 @@ import codecs
 
 from . import exceptions as exc
 from .pq import Format
+from .cursor import BaseCursor
+from .connection import BaseConnection
 
 INVALID_OID = 0
 TEXT_OID = 25
@@ -21,9 +23,44 @@ global_adapters = {}
 global_casters = {}
 
 
-class ValuesTransformer:
+def register_adapter(cls, adapter, context=None, format=Format.TEXT):
+
+    if context is not None and not isinstance(
+        context, (BaseConnection, BaseCursor)
+    ):
+        raise TypeError(
+            f"the context should be a connection or cursor;"
+            f" got {type(context).__name__}"
+        )
+
+    where = context.adapters if context is not None else global_adapters
+    where[cls, format] = adapter
+
+
+def register_binary_adapter(cls, adapter, context=None):
+    register_adapter(cls, adapter, context, format=Format.BINARY)
+
+
+def register_caster(oid, caster, context=None, format=Format.TEXT):
+    if context is not None and not isinstance(
+        context, (BaseConnection, BaseCursor)
+    ):
+        raise TypeError(
+            f"the context should be a connection or cursor;"
+            f" got {type(context).__name__}"
+        )
+
+    where = context.adapters if context is not None else global_casters
+    where[oid, format] = caster
+
+
+def register_binary_caster(oid, caster, context=None):
+    register_caster(oid, caster, context, format=Format.BINARY)
+
+
+class Transformer:
     """
-    An object that can adapt efficiently a number of value.
+    An object that can adapt efficiently between Python and PostgreSQL.
 
     The life cycle of the object is the query, so it is assumed that stuff like
     the server version or connection encoding will not change. It can have its
@@ -31,9 +68,6 @@ class ValuesTransformer:
     """
 
     def __init__(self, context):
-        from .connection import BaseConnection
-        from .cursor import BaseCursor
-
         if context is None:
             self.connection = None
             self.cursor = None
@@ -217,10 +251,11 @@ class StringCaster(Typecaster):
             return data
 
 
-global_adapters[str, Format.TEXT] = StringAdapter
-global_adapters[str, Format.BINARY] = StringAdapter
-global_casters[TEXT_OID, Format.TEXT] = StringCaster
-global_casters[TEXT_OID, Format.BINARY] = StringCaster
+register_adapter(str, StringAdapter)
+register_binary_adapter(str, StringAdapter)
+
+register_caster(TEXT_OID, StringCaster)
+register_binary_caster(TEXT_OID, StringCaster)
 
 
 def adapt_int(obj):
@@ -231,8 +266,8 @@ def cast_int(data):
     return int(ascii_decode(data)[0])
 
 
-global_adapters[int, Format.TEXT] = adapt_int
-global_casters[NUMERIC_OID, Format.TEXT] = cast_int
+register_adapter(int, adapt_int)
+register_caster(NUMERIC_OID, cast_int)
 
 
 class UnknownCaster(Typecaster):
@@ -255,5 +290,5 @@ def binary_cast_unknown(data):
     return data
 
 
-global_casters[INVALID_OID, Format.TEXT] = UnknownCaster
-global_casters[INVALID_OID, Format.BINARY] = binary_cast_unknown
+register_caster(INVALID_OID, UnknownCaster)
+register_binary_caster(INVALID_OID, binary_cast_unknown)
index d92c16b16acba6231a63cdd8d1bdf0f37431f473..e5d63f1f9539b13b3cd23f95492be0546293d12d 100644 (file)
@@ -6,7 +6,6 @@ psycopg3 cursor objects
 
 from . import exceptions as exc
 from .pq import error_message, DiagnosticField, ExecStatus
-from .adaptation import ValuesTransformer
 from .utils.queries import query2pg, reorder_params
 
 
@@ -19,11 +18,13 @@ class BaseCursor:
         self._reset()
 
     def _reset(self):
+        from .adaptation import Transformer
+
         self._results = []
         self._result = None
         self._pos = 0
         self._iresult = 0
-        self._transformer = ValuesTransformer(self)
+        self._transformer = Transformer(self)
 
     def _execute_send(self, query, vars):
         # Implement part of execute() before waiting common to sync and async