From bd50582a3ad93efd9ff2f02cd948c4a6ffa9d2fb Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 29 Mar 2020 00:22:20 +1300 Subject: [PATCH] Making adapters/casters simpler It can be either a function or an Adapter/Typecaster instance, in case configuration based on class/oid/connection is required. Adapter functions can optionally return an oid. --- psycopg3/adaptation.py | 205 ++++++++++++++++++----------------------- 1 file changed, 92 insertions(+), 113 deletions(-) diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py index ba5761141..5bb1b4096 100644 --- a/psycopg3/adaptation.py +++ b/psycopg3/adaptation.py @@ -5,7 +5,6 @@ Entry point into the adaptation system. # Copyright (C) 2020 The Psycopg Team import codecs -from functools import partial from . import exceptions as exc from .pq import Format @@ -13,7 +12,6 @@ from .pq import Format INVALID_OID = 0 TEXT_OID = 25 NUMERIC_OID = 1700 -FLOAT8_INT = 701 ascii_encode = codecs.lookup("ascii").encode ascii_decode = codecs.lookup("ascii").decode @@ -85,7 +83,13 @@ class ValuesTransformer: types = [] for var, fmt in zip(objs, fmts): - data, oid = self.adapt(var, fmt) + data = self.adapt(var, fmt) + if isinstance(data, tuple): + oid = data[1] + data = data[0] + else: + oid = TEXT_OID + out.append(data) types.append(oid) @@ -105,32 +109,29 @@ class ValuesTransformer: except KeyError: pass - xf = self.lookup_adapter(cls) - if fmt == Format.TEXT: - func = self._adapt_funcs[cls, fmt] = xf.get_text_adapter( - cls, self.connection - ) - else: - assert fmt == Format.BINARY - func = self._adapt_funcs[cls, fmt] = xf.get_binary_adapter( - cls, self.connection - ) + adapter = self.lookup_adapter(cls, fmt) + if isinstance(adapter, type): + adapter = adapter(cls, self.connection).adapt - return func + return adapter + + def lookup_adapter(self, cls, fmt): + key = (cls, fmt) - def lookup_adapter(self, cls): cur = self.cursor - if cur is not None and cls in cur.adapters: - return cur.adapters[cls] + if cur is not None and key in cur.adapters: + return cur.adapters[key] conn = self.connection - if conn is not None and cls in conn.adapters: - return conn.adapters[cls] + if conn is not None and key in conn.adapters: + return conn.adapters[key] - if cls in global_adapters: - return global_adapters[cls] + if key in global_adapters: + return global_adapters[key] - raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}") + raise exc.ProgrammingError( + f"cannot adapt type {cls.__name__} to format {fmt}" + ) def cast_row(self, result, n): self.result = result @@ -147,116 +148,91 @@ class ValuesTransformer: except KeyError: pass - xf = self.lookup_caster(oid) - if fmt == Format.TEXT: - func = self._cast_funcs[oid, fmt] = xf.get_text_caster( - oid, self.connection - ) - else: - assert fmt == Format.BINARY - func = self._cast_funcs[oid, fmt] = xf.get_binary_caster( - oid, self.connection - ) + caster = self.lookup_caster(oid, fmt) + if isinstance(caster, type): + caster = caster(oid, self.connection).cast - return func + return caster + + def lookup_caster(self, oid, fmt): + key = (oid, fmt) - def lookup_caster(self, oid): cur = self.cursor - if cur is not None and oid in cur.casters: - return cur.casters[oid] + if cur is not None and key in cur.casters: + return cur.casters[key] conn = self.connection - if conn is not None and oid in conn.casters: - return conn.casters[oid] + if conn is not None and key in conn.casters: + return conn.casters[key] - if oid in global_casters: - return global_casters[oid] - else: - return UnknownCaster() + if key in global_casters: + return global_casters[key] + + return global_casters[INVALID_OID, fmt] class Adapter: - def get_text_adapter(self, cls, conn): - raise exc.NotSupportedError( - f"the type {cls.__name__} doesn't support text adaptation" - ) + def __init__(self, cls, conn): + self.cls = cls + self.conn = conn - def get_binary_adapter(self, cls, conn): - raise exc.NotSupportedError( - f"the type {cls.__name__} doesn't support binary adaptation" - ) + def adapt(self, obj): + raise NotImplementedError() class Typecaster: - def get_text_caster(self, oid, conn): - raise exc.NotSupportedError( - f"the PostgreSQL type {oid} doesn't support cast from text" - ) - - def get_binary_caster(self, oid, conn): - raise exc.NotSupportedError( - f"the PostgreSQL type {oid} doesn't support cast from binary" - ) - - @staticmethod - def cast_to_bytes(value): - return value + def __init__(self, oid, conn): + self.oid = oid + self.conn = conn - @staticmethod - def cast_to_str(codec, value): - return codec.decode(value)[0] + def cast(self, data): + raise NotImplementedError() class StringAdapter(Adapter): - def get_text_adapter(self, cls, conn): - codec = conn.codec if conn is not None else utf8_codec - return partial(self.adapt_str, codec) + def __init__(self, cls, conn): + super().__init__(cls, conn) + self.encode = (conn.codec if conn is not None else utf8_codec).encode - # format is the same in binary and text - get_binary_adapter = get_text_adapter - - @staticmethod - def adapt_str(codec, value): - return codec.encode(value)[0], TEXT_OID + def adapt(self, obj): + return self.encode(obj)[0] class StringCaster(Typecaster): - def get_text_caster(self, oid, conn): - if conn is None or conn.pgenc == b"SQL_ASCII": - # we don't have enough info to decode bytes - return self.unparsed_bytes - - codec = conn.codec - return partial(self.cast_to_str, codec) - - # format is the same in binary and text - get_binary_caster = get_text_caster - + def __init__(self, oid, conn): + super().__init__(oid, conn) + if conn is not None: + if conn.pgenc != b"SQL_ASCII": + self.decode = conn.codec.decode + else: + self.decode = None + else: + self.decode = utf8_codec.decode -global_adapters[str] = StringAdapter() -global_casters[TEXT_OID] = StringCaster() + def cast(self, data): + if self.decode is not None: + return self.decode(data)[0] + else: + # return bytes for SQL_ASCII db + return data -class IntAdapter(Adapter): - def get_text_adapter(self, cls, conn): - return self.adapt_int +global_adapters[str, Format.TEXT] = StringAdapter +global_adapters[str, Format.BINARY] = StringAdapter +global_casters[TEXT_OID, Format.TEXT] = StringCaster +global_casters[TEXT_OID, Format.BINARY] = StringCaster - @staticmethod - def adapt_int(value): - return ascii_encode(str(value))[0], NUMERIC_OID +def adapt_int(obj): + return ascii_encode(str(obj))[0], NUMERIC_OID -class IntCaster(Typecaster): - def get_text_caster(self, oid, conn): - return self.cast_int - @staticmethod - def cast_int(value): - return int(ascii_decode(value)[0]) +def cast_int(data): + return int(ascii_decode(data)[0]) -global_adapters[int] = IntAdapter() -global_casters[NUMERIC_OID] = IntCaster() +global_adapters[int, Format.TEXT] = adapt_int +global_casters[NUMERIC_OID, Format.TEXT] = cast_int class UnknownCaster(Typecaster): @@ -264,17 +240,20 @@ class UnknownCaster(Typecaster): Fallback object to convert unknown types to Python """ - def get_text_caster(self, oid, conn): - if conn is None: - # we don't have enough info to decode bytes - return self.cast_to_bytes + def __init__(self, oid, conn): + super().__init__(oid, conn) + if conn is not None: + self.decode = conn.codec.decode + else: + self.decode = utf8_codec.decode + + def cast(self, data): + return self.decode(data)[0] + - codec = conn.codec - return partial(self.cast_to_str, codec) +def binary_cast_unknown(data): + return data - def get_binary_caster(self, oid, conn): - return self.cast_to_bytes - @staticmethod - def cast_to_str(codec, value): - return codec.decode(value)[0] +global_casters[INVALID_OID, Format.TEXT] = UnknownCaster +global_casters[INVALID_OID, Format.BINARY] = binary_cast_unknown -- 2.47.3