]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Making adapters/casters simpler
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 11:22:20 +0000 (00:22 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 28 Mar 2020 11:34:01 +0000 (00:34 +1300)
It can be either a function or an Adapter/Typecaster instance, in case
configuration based on class/oid/connection is required.

Adapter functions can optionally return an oid.

psycopg3/adaptation.py

index ba5761141ce05879ec1c92f4e3627a6fabc4a288..5bb1b4096084d0866bc9fa5e134ba6eb977e0c50 100644 (file)
@@ -5,7 +5,6 @@ Entry point into the adaptation system.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from functools import partial
 
 from . import exceptions as exc
 from .pq import Format
@@ -13,7 +12,6 @@ from .pq import Format
 INVALID_OID = 0
 TEXT_OID = 25
 NUMERIC_OID = 1700
-FLOAT8_INT = 701
 
 ascii_encode = codecs.lookup("ascii").encode
 ascii_decode = codecs.lookup("ascii").decode
@@ -85,7 +83,13 @@ class ValuesTransformer:
         types = []
 
         for var, fmt in zip(objs, fmts):
-            data, oid = self.adapt(var, fmt)
+            data = self.adapt(var, fmt)
+            if isinstance(data, tuple):
+                oid = data[1]
+                data = data[0]
+            else:
+                oid = TEXT_OID
+
             out.append(data)
             types.append(oid)
 
@@ -105,32 +109,29 @@ class ValuesTransformer:
         except KeyError:
             pass
 
-        xf = self.lookup_adapter(cls)
-        if fmt == Format.TEXT:
-            func = self._adapt_funcs[cls, fmt] = xf.get_text_adapter(
-                cls, self.connection
-            )
-        else:
-            assert fmt == Format.BINARY
-            func = self._adapt_funcs[cls, fmt] = xf.get_binary_adapter(
-                cls, self.connection
-            )
+        adapter = self.lookup_adapter(cls, fmt)
+        if isinstance(adapter, type):
+            adapter = adapter(cls, self.connection).adapt
 
-        return func
+        return adapter
+
+    def lookup_adapter(self, cls, fmt):
+        key = (cls, fmt)
 
-    def lookup_adapter(self, cls):
         cur = self.cursor
-        if cur is not None and cls in cur.adapters:
-            return cur.adapters[cls]
+        if cur is not None and key in cur.adapters:
+            return cur.adapters[key]
 
         conn = self.connection
-        if conn is not None and cls in conn.adapters:
-            return conn.adapters[cls]
+        if conn is not None and key in conn.adapters:
+            return conn.adapters[key]
 
-        if cls in global_adapters:
-            return global_adapters[cls]
+        if key in global_adapters:
+            return global_adapters[key]
 
-        raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}")
+        raise exc.ProgrammingError(
+            f"cannot adapt type {cls.__name__} to format {fmt}"
+        )
 
     def cast_row(self, result, n):
         self.result = result
@@ -147,116 +148,91 @@ class ValuesTransformer:
         except KeyError:
             pass
 
-        xf = self.lookup_caster(oid)
-        if fmt == Format.TEXT:
-            func = self._cast_funcs[oid, fmt] = xf.get_text_caster(
-                oid, self.connection
-            )
-        else:
-            assert fmt == Format.BINARY
-            func = self._cast_funcs[oid, fmt] = xf.get_binary_caster(
-                oid, self.connection
-            )
+        caster = self.lookup_caster(oid, fmt)
+        if isinstance(caster, type):
+            caster = caster(oid, self.connection).cast
 
-        return func
+        return caster
+
+    def lookup_caster(self, oid, fmt):
+        key = (oid, fmt)
 
-    def lookup_caster(self, oid):
         cur = self.cursor
-        if cur is not None and oid in cur.casters:
-            return cur.casters[oid]
+        if cur is not None and key in cur.casters:
+            return cur.casters[key]
 
         conn = self.connection
-        if conn is not None and oid in conn.casters:
-            return conn.casters[oid]
+        if conn is not None and key in conn.casters:
+            return conn.casters[key]
 
-        if oid in global_casters:
-            return global_casters[oid]
-        else:
-            return UnknownCaster()
+        if key in global_casters:
+            return global_casters[key]
+
+        return global_casters[INVALID_OID, fmt]
 
 
 class Adapter:
-    def get_text_adapter(self, cls, conn):
-        raise exc.NotSupportedError(
-            f"the type {cls.__name__} doesn't support text adaptation"
-        )
+    def __init__(self, cls, conn):
+        self.cls = cls
+        self.conn = conn
 
-    def get_binary_adapter(self, cls, conn):
-        raise exc.NotSupportedError(
-            f"the type {cls.__name__} doesn't support binary adaptation"
-        )
+    def adapt(self, obj):
+        raise NotImplementedError()
 
 
 class Typecaster:
-    def get_text_caster(self, oid, conn):
-        raise exc.NotSupportedError(
-            f"the PostgreSQL type {oid} doesn't support cast from text"
-        )
-
-    def get_binary_caster(self, oid, conn):
-        raise exc.NotSupportedError(
-            f"the PostgreSQL type {oid} doesn't support cast from binary"
-        )
-
-    @staticmethod
-    def cast_to_bytes(value):
-        return value
+    def __init__(self, oid, conn):
+        self.oid = oid
+        self.conn = conn
 
-    @staticmethod
-    def cast_to_str(codec, value):
-        return codec.decode(value)[0]
+    def cast(self, data):
+        raise NotImplementedError()
 
 
 class StringAdapter(Adapter):
-    def get_text_adapter(self, cls, conn):
-        codec = conn.codec if conn is not None else utf8_codec
-        return partial(self.adapt_str, codec)
+    def __init__(self, cls, conn):
+        super().__init__(cls, conn)
+        self.encode = (conn.codec if conn is not None else utf8_codec).encode
 
-    # format is the same in binary and text
-    get_binary_adapter = get_text_adapter
-
-    @staticmethod
-    def adapt_str(codec, value):
-        return codec.encode(value)[0], TEXT_OID
+    def adapt(self, obj):
+        return self.encode(obj)[0]
 
 
 class StringCaster(Typecaster):
-    def get_text_caster(self, oid, conn):
-        if conn is None or conn.pgenc == b"SQL_ASCII":
-            # we don't have enough info to decode bytes
-            return self.unparsed_bytes
-
-        codec = conn.codec
-        return partial(self.cast_to_str, codec)
-
-    # format is the same in binary and text
-    get_binary_caster = get_text_caster
-
+    def __init__(self, oid, conn):
+        super().__init__(oid, conn)
+        if conn is not None:
+            if conn.pgenc != b"SQL_ASCII":
+                self.decode = conn.codec.decode
+            else:
+                self.decode = None
+        else:
+            self.decode = utf8_codec.decode
 
-global_adapters[str] = StringAdapter()
-global_casters[TEXT_OID] = StringCaster()
+    def cast(self, data):
+        if self.decode is not None:
+            return self.decode(data)[0]
+        else:
+            # return bytes for SQL_ASCII db
+            return data
 
 
-class IntAdapter(Adapter):
-    def get_text_adapter(self, cls, conn):
-        return self.adapt_int
+global_adapters[str, Format.TEXT] = StringAdapter
+global_adapters[str, Format.BINARY] = StringAdapter
+global_casters[TEXT_OID, Format.TEXT] = StringCaster
+global_casters[TEXT_OID, Format.BINARY] = StringCaster
 
-    @staticmethod
-    def adapt_int(value):
-        return ascii_encode(str(value))[0], NUMERIC_OID
 
+def adapt_int(obj):
+    return ascii_encode(str(obj))[0], NUMERIC_OID
 
-class IntCaster(Typecaster):
-    def get_text_caster(self, oid, conn):
-        return self.cast_int
 
-    @staticmethod
-    def cast_int(value):
-        return int(ascii_decode(value)[0])
+def cast_int(data):
+    return int(ascii_decode(data)[0])
 
 
-global_adapters[int] = IntAdapter()
-global_casters[NUMERIC_OID] = IntCaster()
+global_adapters[int, Format.TEXT] = adapt_int
+global_casters[NUMERIC_OID, Format.TEXT] = cast_int
 
 
 class UnknownCaster(Typecaster):
@@ -264,17 +240,20 @@ class UnknownCaster(Typecaster):
     Fallback object to convert unknown types to Python
     """
 
-    def get_text_caster(self, oid, conn):
-        if conn is None:
-            # we don't have enough info to decode bytes
-            return self.cast_to_bytes
+    def __init__(self, oid, conn):
+        super().__init__(oid, conn)
+        if conn is not None:
+            self.decode = conn.codec.decode
+        else:
+            self.decode = utf8_codec.decode
+
+    def cast(self, data):
+        return self.decode(data)[0]
+
 
-        codec = conn.codec
-        return partial(self.cast_to_str, codec)
+def binary_cast_unknown(data):
+    return data
 
-    def get_binary_caster(self, oid, conn):
-        return self.cast_to_bytes
 
-    @staticmethod
-    def cast_to_str(codec, value):
-        return codec.decode(value)[0]
+global_casters[INVALID_OID, Format.TEXT] = UnknownCaster
+global_casters[INVALID_OID, Format.BINARY] = binary_cast_unknown