# Copyright (C) 2020 The Psycopg Team
import codecs
-from functools import partial
from . import exceptions as exc
from .pq import Format
INVALID_OID = 0
TEXT_OID = 25
NUMERIC_OID = 1700
-FLOAT8_INT = 701
ascii_encode = codecs.lookup("ascii").encode
ascii_decode = codecs.lookup("ascii").decode
types = []
for var, fmt in zip(objs, fmts):
- data, oid = self.adapt(var, fmt)
+ data = self.adapt(var, fmt)
+ if isinstance(data, tuple):
+ oid = data[1]
+ data = data[0]
+ else:
+ oid = TEXT_OID
+
out.append(data)
types.append(oid)
except KeyError:
pass
- xf = self.lookup_adapter(cls)
- if fmt == Format.TEXT:
- func = self._adapt_funcs[cls, fmt] = xf.get_text_adapter(
- cls, self.connection
- )
- else:
- assert fmt == Format.BINARY
- func = self._adapt_funcs[cls, fmt] = xf.get_binary_adapter(
- cls, self.connection
- )
+ adapter = self.lookup_adapter(cls, fmt)
+ if isinstance(adapter, type):
+ adapter = adapter(cls, self.connection).adapt
- return func
+ return adapter
+
+ def lookup_adapter(self, cls, fmt):
+ key = (cls, fmt)
- def lookup_adapter(self, cls):
cur = self.cursor
- if cur is not None and cls in cur.adapters:
- return cur.adapters[cls]
+ if cur is not None and key in cur.adapters:
+ return cur.adapters[key]
conn = self.connection
- if conn is not None and cls in conn.adapters:
- return conn.adapters[cls]
+ if conn is not None and key in conn.adapters:
+ return conn.adapters[key]
- if cls in global_adapters:
- return global_adapters[cls]
+ if key in global_adapters:
+ return global_adapters[key]
- raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}")
+ raise exc.ProgrammingError(
+ f"cannot adapt type {cls.__name__} to format {fmt}"
+ )
def cast_row(self, result, n):
self.result = result
except KeyError:
pass
- xf = self.lookup_caster(oid)
- if fmt == Format.TEXT:
- func = self._cast_funcs[oid, fmt] = xf.get_text_caster(
- oid, self.connection
- )
- else:
- assert fmt == Format.BINARY
- func = self._cast_funcs[oid, fmt] = xf.get_binary_caster(
- oid, self.connection
- )
+ caster = self.lookup_caster(oid, fmt)
+ if isinstance(caster, type):
+ caster = caster(oid, self.connection).cast
- return func
+ return caster
+
+ def lookup_caster(self, oid, fmt):
+ key = (oid, fmt)
- def lookup_caster(self, oid):
cur = self.cursor
- if cur is not None and oid in cur.casters:
- return cur.casters[oid]
+ if cur is not None and key in cur.casters:
+ return cur.casters[key]
conn = self.connection
- if conn is not None and oid in conn.casters:
- return conn.casters[oid]
+ if conn is not None and key in conn.casters:
+ return conn.casters[key]
- if oid in global_casters:
- return global_casters[oid]
- else:
- return UnknownCaster()
+ if key in global_casters:
+ return global_casters[key]
+
+ return global_casters[INVALID_OID, fmt]
class Adapter:
- def get_text_adapter(self, cls, conn):
- raise exc.NotSupportedError(
- f"the type {cls.__name__} doesn't support text adaptation"
- )
+ def __init__(self, cls, conn):
+ self.cls = cls
+ self.conn = conn
- def get_binary_adapter(self, cls, conn):
- raise exc.NotSupportedError(
- f"the type {cls.__name__} doesn't support binary adaptation"
- )
+ def adapt(self, obj):
+ raise NotImplementedError()
class Typecaster:
- def get_text_caster(self, oid, conn):
- raise exc.NotSupportedError(
- f"the PostgreSQL type {oid} doesn't support cast from text"
- )
-
- def get_binary_caster(self, oid, conn):
- raise exc.NotSupportedError(
- f"the PostgreSQL type {oid} doesn't support cast from binary"
- )
-
- @staticmethod
- def cast_to_bytes(value):
- return value
+ def __init__(self, oid, conn):
+ self.oid = oid
+ self.conn = conn
- @staticmethod
- def cast_to_str(codec, value):
- return codec.decode(value)[0]
+ def cast(self, data):
+ raise NotImplementedError()
class StringAdapter(Adapter):
- def get_text_adapter(self, cls, conn):
- codec = conn.codec if conn is not None else utf8_codec
- return partial(self.adapt_str, codec)
+ def __init__(self, cls, conn):
+ super().__init__(cls, conn)
+ self.encode = (conn.codec if conn is not None else utf8_codec).encode
- # format is the same in binary and text
- get_binary_adapter = get_text_adapter
-
- @staticmethod
- def adapt_str(codec, value):
- return codec.encode(value)[0], TEXT_OID
+ def adapt(self, obj):
+ return self.encode(obj)[0]
class StringCaster(Typecaster):
- def get_text_caster(self, oid, conn):
- if conn is None or conn.pgenc == b"SQL_ASCII":
- # we don't have enough info to decode bytes
- return self.unparsed_bytes
-
- codec = conn.codec
- return partial(self.cast_to_str, codec)
-
- # format is the same in binary and text
- get_binary_caster = get_text_caster
-
+ def __init__(self, oid, conn):
+ super().__init__(oid, conn)
+ if conn is not None:
+ if conn.pgenc != b"SQL_ASCII":
+ self.decode = conn.codec.decode
+ else:
+ self.decode = None
+ else:
+ self.decode = utf8_codec.decode
-global_adapters[str] = StringAdapter()
-global_casters[TEXT_OID] = StringCaster()
+ def cast(self, data):
+ if self.decode is not None:
+ return self.decode(data)[0]
+ else:
+ # return bytes for SQL_ASCII db
+ return data
-class IntAdapter(Adapter):
- def get_text_adapter(self, cls, conn):
- return self.adapt_int
+global_adapters[str, Format.TEXT] = StringAdapter
+global_adapters[str, Format.BINARY] = StringAdapter
+global_casters[TEXT_OID, Format.TEXT] = StringCaster
+global_casters[TEXT_OID, Format.BINARY] = StringCaster
- @staticmethod
- def adapt_int(value):
- return ascii_encode(str(value))[0], NUMERIC_OID
+def adapt_int(obj):
+ return ascii_encode(str(obj))[0], NUMERIC_OID
-class IntCaster(Typecaster):
- def get_text_caster(self, oid, conn):
- return self.cast_int
- @staticmethod
- def cast_int(value):
- return int(ascii_decode(value)[0])
+def cast_int(data):
+ return int(ascii_decode(data)[0])
-global_adapters[int] = IntAdapter()
-global_casters[NUMERIC_OID] = IntCaster()
+global_adapters[int, Format.TEXT] = adapt_int
+global_casters[NUMERIC_OID, Format.TEXT] = cast_int
class UnknownCaster(Typecaster):
Fallback object to convert unknown types to Python
"""
- def get_text_caster(self, oid, conn):
- if conn is None:
- # we don't have enough info to decode bytes
- return self.cast_to_bytes
+ def __init__(self, oid, conn):
+ super().__init__(oid, conn)
+ if conn is not None:
+ self.decode = conn.codec.decode
+ else:
+ self.decode = utf8_codec.decode
+
+ def cast(self, data):
+ return self.decode(data)[0]
+
- codec = conn.codec
- return partial(self.cast_to_str, codec)
+def binary_cast_unknown(data):
+ return data
- def get_binary_caster(self, oid, conn):
- return self.cast_to_bytes
- @staticmethod
- def cast_to_str(codec, value):
- return codec.decode(value)[0]
+global_casters[INVALID_OID, Format.TEXT] = UnknownCaster
+global_casters[INVALID_OID, Format.BINARY] = binary_cast_unknown