From: Daniele Varrazzo Date: Fri, 27 Mar 2020 09:02:42 +0000 (+1300) Subject: Added sketch of adaptation layer X-Git-Tag: 3.0.dev0~666 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=88b3a29d9f94eac6fa58df09d0162a03a2a0c0f2;p=thirdparty%2Fpsycopg.git Added sketch of adaptation layer --- diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py new file mode 100644 index 000000000..afc33c7d8 --- /dev/null +++ b/psycopg3/adaptation.py @@ -0,0 +1,147 @@ +""" +Entry point into the adaptation system. +""" + +# Copyright (C) 2020 The Psycopg Team + +import codecs + +from . import exceptions as exc +from .pq import Format + +INVALID_OID = 0 +TEXT_OID = 25 +NUMERIC_OID = 1700 +FLOAT8_INT = 701 + +ascii_encode = codecs.lookup("ascii").encode +utf8_codec = codecs.lookup("utf-8") + + +class ValuesAdapter: + """ + An object that can adapt efficiently a number of value. + + The life cycle of the object is the query, so it is assumed that stuff like + the server version or connection encoding will not change. It can have its + state so adapting several values of the same type can use optimisations. + """ + + def __init__(self, context): + from .connection import BaseConnection + from .cursor import BaseCursor + + if context is None: + self.connection = None + self.cursor = None + elif isinstance(context, BaseConnection): + self.connection = context + self.cursor = None + elif isinstance(context, BaseCursor): + self.connection = context.conn + self.cursor = context + else: + raise TypeError( + f"the context should be a connection or cursor," + f" got {type(context).__name__}") + + # mapping class -> adaptation function + self._adapt_funcs = {} + + def adapt_sequence(self, objs, fmts): + out = [] + types = [] + + for var, fmt in zip(objs, fmts): + data, oid = self.adapt(var, fmt) + out.append(data) + types.append(oid) + + return out, types + + def adapt(self, obj, fmt): + if obj is None: + return None, TEXT_OID + + cls = type(obj) + try: + func = self._adapt_funcs[cls, fmt] + except KeyError: + pass + else: + return func(obj) + + adapter = self.lookup_adapter(cls) + if fmt == Format.TEXT: + func = self._adapt_funcs[cls, fmt] = adapter.get_text_adapter( + cls, self.connection + ) + else: + assert fmt == Format.BINARY + func = self._adapt_funcs[cls, fmt] = adapter.get_binary_adapter( + cls, self.connection + ) + + return func(obj) + + def lookup_adapter(self, cls): + cur = self.cursor + if ( + cur is not None + and cls in cur.adapters + ): + return cur.adapters[cls] + + conn = self.connection + if ( + conn is not None + and cls in conn.adapters + ): + return conn.adapters[cls] + + if cls in global_adapters: + return global_adapters[cls] + + raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}") + + +global_adapters = {} + + +class Adapter: + def get_text_adapter(self, cls, conn): + raise exc.NotSupportedError( + f"the type {cls.__name__} doesn't support text adaptation" + ) + + def get_binary_adapter(self, cls, conn): + raise exc.NotSupportedError( + f"the type {cls.__name__} doesn't support binary adaptation" + ) + + +class StringAdapter(Adapter): + def get_text_adapter(self, cls, conn): + codec = conn.codec if conn is not None else utf8_codec + + def adapt_text(value): + return codec.encode(value)[0], TEXT_OID + + return adapt_text + + # format is the same in binary and text + get_binary_adapter = get_text_adapter + + +global_adapters[str] = StringAdapter() + + +class IntAdapter(Adapter): + def get_text_adapter(self, cls, conn): + return self.adapt_int + + def adapt_int(self, value): + return ascii_encode(str(value))[0], NUMERIC_OID + + +global_adapters[int] = IntAdapter() diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 86ef56f0b..874e2e448 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -29,6 +29,7 @@ class BaseConnection: def __init__(self, pgconn): self.pgconn = pgconn self.cursor_factory = None + self.adapters = {} # name of the postgres encoding (in bytes) self._pgenc = None diff --git a/psycopg3/cursor.py b/psycopg3/cursor.py index 603ffdcf9..3bbb8a96c 100644 --- a/psycopg3/cursor.py +++ b/psycopg3/cursor.py @@ -6,6 +6,7 @@ psycopg3 cursor objects from . import exceptions as exc from .pq import error_message, DiagnosticField, ExecStatus +from .adaptation import ValuesAdapter from .utils.queries import query2pg, reorder_params @@ -13,6 +14,7 @@ class BaseCursor: def __init__(self, conn, binary=False): self.conn = conn self.binary = binary + self.adapters = {} self._results = [] self._result = None self._iresult = 0 @@ -33,9 +35,10 @@ class BaseCursor: if vars: if order is not None: vars = reorder_params(vars, order) - params = self._adapt_sequence(vars, formats) + adapter = ValuesAdapter(self) + params, types = adapter.adapt_sequence(vars, formats) self.conn.pgconn.send_query_params( - query, params, param_formats=formats + query, params, param_formats=formats, param_types=types ) else: self.conn.pgconn.send_query(query) @@ -83,14 +86,6 @@ class BaseCursor: self._result = self._results[self._iresult] return True - def _adapt_sequence(self, vars, formats): - # TODO: stub. Need adaptation layer. - codec = self.conn.codec - out = [ - codec.encode(str(v))[0] if v is not None else None for v in vars - ] - return out - class Cursor(BaseCursor): def execute(self, query, vars=None): diff --git a/tests/pq/test_exec.py b/tests/pq/test_exec.py index e7579d71b..af8322e3e 100644 --- a/tests/pq/test_exec.py +++ b/tests/pq/test_exec.py @@ -38,9 +38,9 @@ def test_exec_params_types(pq, pgconn): def test_exec_params_nulls(pq, pgconn): - if pgconn.server_version < 100000: - pytest.xfail("it doesn't work on pg < 10") - res = pgconn.exec_params(b"select $1, $2, $3", [b"hi", b"", None]) + res = pgconn.exec_params( + b"select $1::text, $2::text, $3::text", [b"hi", b"", None] + ) assert res.status == pq.ExecStatus.TUPLES_OK assert res.get_value(0, 0) == b"hi" assert res.get_value(0, 1) == b"" diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py index 4c9e0e587..4c0dfd57c 100644 --- a/tests/test_async_cursor.py +++ b/tests/test_async_cursor.py @@ -1,6 +1,3 @@ -import pytest - - def test_execute_many(aconn, loop): cur = aconn.cursor() rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'")) @@ -13,8 +10,6 @@ def test_execute_many(aconn, loop): def test_execute_sequence(aconn, loop): - if aconn.pgconn.server_version < 100000: - pytest.xfail("it doesn't work on pg < 10") cur = aconn.cursor() rv = loop.run_until_complete( cur.execute("select %s, %s, %s", [1, "foo", None]) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 658153ff9..52fc3f817 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,6 +1,3 @@ -import pytest - - def test_execute_many(conn): cur = conn.cursor() rv = cur.execute("select 'foo'; select 'bar'") @@ -13,8 +10,6 @@ def test_execute_many(conn): def test_execute_sequence(conn): - if conn.pgconn.server_version < 100000: - pytest.xfail("it doesn't work on pg < 10") cur = conn.cursor() rv = cur.execute("select %s, %s, %s", [1, "foo", None]) assert rv is cur