# Copyright (C) 2020 The Psycopg Team
import codecs
+from functools import partial
from . import exceptions as exc
from .pq import Format
FLOAT8_INT = 701
ascii_encode = codecs.lookup("ascii").encode
+ascii_decode = codecs.lookup("ascii").decode
utf8_codec = codecs.lookup("utf-8")
+global_adapters = {}
+global_casters = {}
+
-class ValuesAdapter:
+class ValuesTransformer:
"""
An object that can adapt efficiently a number of value.
else:
raise TypeError(
f"the context should be a connection or cursor,"
- f" got {type(context).__name__}")
+ f" got {type(context).__name__}"
+ )
- # mapping class -> adaptation function
+ # mapping class, fmt -> adaptation function
self._adapt_funcs = {}
+ # mapping oid, fmt -> cast function
+ self._cast_funcs = {}
+
+ # The result to return values from
+ self._result = None
+
+ # sequence of cast function from value to python
+ # the length of the result columns
+ self._row_casters = None
+
+ @property
+ def result(self):
+ return self._result
+
+ @result.setter
+ def result(self, result):
+ if self._result is result:
+ return
+
+ rc = self._row_casters = []
+ for c in range(result.nfields):
+ oid = result.ftype(c)
+ fmt = result.fformat(c)
+ func = self.get_cast_function(oid, fmt)
+ rc.append(func)
+
def adapt_sequence(self, objs, fmts):
out = []
types = []
return None, TEXT_OID
cls = type(obj)
+ func = self.get_adapt_function(cls, fmt)
+ return func(obj)
+
+ def get_adapt_function(self, cls, fmt):
try:
- func = self._adapt_funcs[cls, fmt]
+ return self._adapt_funcs[cls, fmt]
except KeyError:
pass
- else:
- return func(obj)
- adapter = self.lookup_adapter(cls)
+ xf = self.lookup_adapter(cls)
if fmt == Format.TEXT:
- func = self._adapt_funcs[cls, fmt] = adapter.get_text_adapter(
+ func = self._adapt_funcs[cls, fmt] = xf.get_text_adapter(
cls, self.connection
)
else:
assert fmt == Format.BINARY
- func = self._adapt_funcs[cls, fmt] = adapter.get_binary_adapter(
+ func = self._adapt_funcs[cls, fmt] = xf.get_binary_adapter(
cls, self.connection
)
- return func(obj)
+ return func
def lookup_adapter(self, cls):
cur = self.cursor
- if (
- cur is not None
- and cls in cur.adapters
- ):
+ if cur is not None and cls in cur.adapters:
return cur.adapters[cls]
conn = self.connection
- if (
- conn is not None
- and cls in conn.adapters
- ):
+ if conn is not None and cls in conn.adapters:
return conn.adapters[cls]
if cls in global_adapters:
raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}")
+ def cast_row(self, result, n):
+ self.result = result
-global_adapters = {}
+ for col, func in enumerate(self._row_casters):
+ v = result.get_value(n, col)
+ if v is not None:
+ v = func(v)
+ yield v
+
+ def get_cast_function(self, oid, fmt):
+ try:
+ return self._cast_funcs[oid, fmt]
+ except KeyError:
+ pass
+
+ xf = self.lookup_caster(oid)
+ if fmt == Format.TEXT:
+ func = self._cast_funcs[oid, fmt] = xf.get_text_caster(
+ oid, self.connection
+ )
+ else:
+ assert fmt == Format.BINARY
+ func = self._cast_funcs[oid, fmt] = xf.get_binary_caster(
+ oid, self.connection
+ )
+
+ return func
+ def lookup_caster(self, oid):
+ cur = self.cursor
+ if cur is not None and oid in cur.casters:
+ return cur.casters[oid]
+
+ conn = self.connection
+ if conn is not None and oid in conn.casters:
+ return conn.casters[oid]
-class Adapter:
+ if oid in global_casters:
+ return global_casters[oid]
+ else:
+ return UnknownCaster()
+
+
+class Transformer:
def get_text_adapter(self, cls, conn):
raise exc.NotSupportedError(
f"the type {cls.__name__} doesn't support text adaptation"
f"the type {cls.__name__} doesn't support binary adaptation"
)
+ def get_text_caster(self, oid, conn):
+ raise exc.NotSupportedError(
+ f"the PostgreSQL type {oid} doesn't support cast from text"
+ )
+
+ def get_binary_caster(self, oid, conn):
+ raise exc.NotSupportedError(
+ f"the PostgreSQL type {oid} doesn't support cast from binary"
+ )
+
+ @staticmethod
+ def cast_to_bytes(value):
+ return value
+
+ @staticmethod
+ def cast_to_str(codec, value):
+ return codec.decode(value)[0]
-class StringAdapter(Adapter):
+
+class StringTransformer(Transformer):
def get_text_adapter(self, cls, conn):
codec = conn.codec if conn is not None else utf8_codec
+ return partial(self.adapt_str, codec)
- def adapt_text(value):
- return codec.encode(value)[0], TEXT_OID
+ def get_text_caster(self, oid, conn):
+ if conn is None or conn.pgenc == b"SQL_ASCII":
+ # we don't have enough info to decode bytes
+ return self.unparsed_bytes
- return adapt_text
+ codec = conn.codec
+ return partial(self.cast_to_str, codec)
# format is the same in binary and text
get_binary_adapter = get_text_adapter
+ get_binary_caster = get_text_caster
+
+ @staticmethod
+ def adapt_str(codec, value):
+ return codec.encode(value)[0], TEXT_OID
-global_adapters[str] = StringAdapter()
+global_adapters[str] = global_casters[TEXT_OID] = StringTransformer()
-class IntAdapter(Adapter):
+class IntTransformer(Transformer):
def get_text_adapter(self, cls, conn):
return self.adapt_int
- def adapt_int(self, value):
+ def get_text_caster(self, oid, conn):
+ return self.cast_int
+
+ @staticmethod
+ def adapt_int(value):
return ascii_encode(str(value))[0], NUMERIC_OID
+ @staticmethod
+ def cast_int(value):
+ return int(ascii_decode(value)[0])
+
+
+global_casters[NUMERIC_OID] = global_adapters[int] = IntTransformer()
+
+
+class UnknownCaster(Transformer):
+ """
+ Fallback object to convert unknown types to Python
+ """
+
+ def get_text_caster(self, oid, conn):
+ if conn is None:
+ # we don't have enough info to decode bytes
+ return self.cast_to_bytes
+
+ codec = conn.codec
+ return partial(self.cast_to_str, codec)
+
+ def get_binary_caster(self, oid, conn):
+ return self.cast_to_bytes
-global_adapters[int] = IntAdapter()
+ @staticmethod
+ def cast_to_str(codec, value):
+ return codec.decode(value)[0]
from . import exceptions as exc
from .pq import error_message, DiagnosticField, ExecStatus
-from .adaptation import ValuesAdapter
+from .adaptation import ValuesTransformer
from .utils.queries import query2pg, reorder_params
self.conn = conn
self.binary = binary
self.adapters = {}
+ self.casters = {}
+ self._reset()
+
+ def _reset(self):
self._results = []
self._result = None
+ self._pos = 0
self._iresult = 0
+ self._transformer = ValuesTransformer(self)
def _execute_send(self, query, vars):
# Implement part of execute() before waiting common to sync and async
- self._results = []
- self._result = None
- self._iresult = 0
+ self._reset()
+
codec = self.conn.codec
if isinstance(query, str):
if vars:
if order is not None:
vars = reorder_params(vars, order)
- adapter = ValuesAdapter(self)
- params, types = adapter.adapt_sequence(vars, formats)
+ params, types = self._transformer.adapt_sequence(vars, formats)
self.conn.pgconn.send_query_params(
query, params, param_formats=formats, param_types=types
)
self._iresult += 1
if self._iresult < len(self._results):
self._result = self._results[self._iresult]
+ self._pos = 0
return True
+ def fetchone(self):
+ rv = self._cast_row(self._pos)
+ if rv is not None:
+ self._pos += 1
+ return rv
+
+ def _cast_row(self, n):
+ if self._result is None:
+ return None
+ if n >= self._result.ntuples:
+ return None
+
+ return tuple(self._transformer.cast_row(self._result, n))
+
class Cursor(BaseCursor):
def execute(self, query, vars=None):