--- /dev/null
+"""
+Entry point into the adaptation system.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+
+from . import exceptions as exc
+from .pq import Format
+
+INVALID_OID = 0
+TEXT_OID = 25
+NUMERIC_OID = 1700
+FLOAT8_INT = 701
+
+ascii_encode = codecs.lookup("ascii").encode
+utf8_codec = codecs.lookup("utf-8")
+
+
+class ValuesAdapter:
+ """
+ An object that can adapt efficiently a number of value.
+
+ The life cycle of the object is the query, so it is assumed that stuff like
+ the server version or connection encoding will not change. It can have its
+ state so adapting several values of the same type can use optimisations.
+ """
+
+ def __init__(self, context):
+ from .connection import BaseConnection
+ from .cursor import BaseCursor
+
+ if context is None:
+ self.connection = None
+ self.cursor = None
+ elif isinstance(context, BaseConnection):
+ self.connection = context
+ self.cursor = None
+ elif isinstance(context, BaseCursor):
+ self.connection = context.conn
+ self.cursor = context
+ else:
+ raise TypeError(
+ f"the context should be a connection or cursor,"
+ f" got {type(context).__name__}")
+
+ # mapping class -> adaptation function
+ self._adapt_funcs = {}
+
+ def adapt_sequence(self, objs, fmts):
+ out = []
+ types = []
+
+ for var, fmt in zip(objs, fmts):
+ data, oid = self.adapt(var, fmt)
+ out.append(data)
+ types.append(oid)
+
+ return out, types
+
+ def adapt(self, obj, fmt):
+ if obj is None:
+ return None, TEXT_OID
+
+ cls = type(obj)
+ try:
+ func = self._adapt_funcs[cls, fmt]
+ except KeyError:
+ pass
+ else:
+ return func(obj)
+
+ adapter = self.lookup_adapter(cls)
+ if fmt == Format.TEXT:
+ func = self._adapt_funcs[cls, fmt] = adapter.get_text_adapter(
+ cls, self.connection
+ )
+ else:
+ assert fmt == Format.BINARY
+ func = self._adapt_funcs[cls, fmt] = adapter.get_binary_adapter(
+ cls, self.connection
+ )
+
+ return func(obj)
+
+ def lookup_adapter(self, cls):
+ cur = self.cursor
+ if (
+ cur is not None
+ and cls in cur.adapters
+ ):
+ return cur.adapters[cls]
+
+ conn = self.connection
+ if (
+ conn is not None
+ and cls in conn.adapters
+ ):
+ return conn.adapters[cls]
+
+ if cls in global_adapters:
+ return global_adapters[cls]
+
+ raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}")
+
+
+global_adapters = {}
+
+
+class Adapter:
+ def get_text_adapter(self, cls, conn):
+ raise exc.NotSupportedError(
+ f"the type {cls.__name__} doesn't support text adaptation"
+ )
+
+ def get_binary_adapter(self, cls, conn):
+ raise exc.NotSupportedError(
+ f"the type {cls.__name__} doesn't support binary adaptation"
+ )
+
+
+class StringAdapter(Adapter):
+ def get_text_adapter(self, cls, conn):
+ codec = conn.codec if conn is not None else utf8_codec
+
+ def adapt_text(value):
+ return codec.encode(value)[0], TEXT_OID
+
+ return adapt_text
+
+ # format is the same in binary and text
+ get_binary_adapter = get_text_adapter
+
+
+global_adapters[str] = StringAdapter()
+
+
+class IntAdapter(Adapter):
+ def get_text_adapter(self, cls, conn):
+ return self.adapt_int
+
+ def adapt_int(self, value):
+ return ascii_encode(str(value))[0], NUMERIC_OID
+
+
+global_adapters[int] = IntAdapter()
from . import exceptions as exc
from .pq import error_message, DiagnosticField, ExecStatus
+from .adaptation import ValuesAdapter
from .utils.queries import query2pg, reorder_params
def __init__(self, conn, binary=False):
self.conn = conn
self.binary = binary
+ self.adapters = {}
self._results = []
self._result = None
self._iresult = 0
if vars:
if order is not None:
vars = reorder_params(vars, order)
- params = self._adapt_sequence(vars, formats)
+ adapter = ValuesAdapter(self)
+ params, types = adapter.adapt_sequence(vars, formats)
self.conn.pgconn.send_query_params(
- query, params, param_formats=formats
+ query, params, param_formats=formats, param_types=types
)
else:
self.conn.pgconn.send_query(query)
self._result = self._results[self._iresult]
return True
- def _adapt_sequence(self, vars, formats):
- # TODO: stub. Need adaptation layer.
- codec = self.conn.codec
- out = [
- codec.encode(str(v))[0] if v is not None else None for v in vars
- ]
- return out
-
class Cursor(BaseCursor):
def execute(self, query, vars=None):
def test_exec_params_nulls(pq, pgconn):
- if pgconn.server_version < 100000:
- pytest.xfail("it doesn't work on pg < 10")
- res = pgconn.exec_params(b"select $1, $2, $3", [b"hi", b"", None])
+ res = pgconn.exec_params(
+ b"select $1::text, $2::text, $3::text", [b"hi", b"", None]
+ )
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == b"hi"
assert res.get_value(0, 1) == b""