]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added sketch of adaptation layer
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Mar 2020 09:02:42 +0000 (22:02 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Mar 2020 12:33:51 +0000 (01:33 +1300)
psycopg3/adaptation.py [new file with mode: 0644]
psycopg3/connection.py
psycopg3/cursor.py
tests/pq/test_exec.py
tests/test_async_cursor.py
tests/test_cursor.py

diff --git a/psycopg3/adaptation.py b/psycopg3/adaptation.py
new file mode 100644 (file)
index 0000000..afc33c7
--- /dev/null
@@ -0,0 +1,147 @@
+"""
+Entry point into the adaptation system.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import codecs
+
+from . import exceptions as exc
+from .pq import Format
+
+INVALID_OID = 0
+TEXT_OID = 25
+NUMERIC_OID = 1700
+FLOAT8_INT = 701
+
+ascii_encode = codecs.lookup("ascii").encode
+utf8_codec = codecs.lookup("utf-8")
+
+
+class ValuesAdapter:
+    """
+    An object that can adapt efficiently a number of value.
+
+    The life cycle of the object is the query, so it is assumed that stuff like
+    the server version or connection encoding will not change. It can have its
+    state so adapting several values of the same type can use optimisations.
+    """
+
+    def __init__(self, context):
+        from .connection import BaseConnection
+        from .cursor import BaseCursor
+
+        if context is None:
+            self.connection = None
+            self.cursor = None
+        elif isinstance(context, BaseConnection):
+            self.connection = context
+            self.cursor = None
+        elif isinstance(context, BaseCursor):
+            self.connection = context.conn
+            self.cursor = context
+        else:
+            raise TypeError(
+                f"the context should be a connection or cursor,"
+                f" got {type(context).__name__}")
+
+        # mapping class -> adaptation function
+        self._adapt_funcs = {}
+
+    def adapt_sequence(self, objs, fmts):
+        out = []
+        types = []
+
+        for var, fmt in zip(objs, fmts):
+            data, oid = self.adapt(var, fmt)
+            out.append(data)
+            types.append(oid)
+
+        return out, types
+
+    def adapt(self, obj, fmt):
+        if obj is None:
+            return None, TEXT_OID
+
+        cls = type(obj)
+        try:
+            func = self._adapt_funcs[cls, fmt]
+        except KeyError:
+            pass
+        else:
+            return func(obj)
+
+        adapter = self.lookup_adapter(cls)
+        if fmt == Format.TEXT:
+            func = self._adapt_funcs[cls, fmt] = adapter.get_text_adapter(
+                cls, self.connection
+            )
+        else:
+            assert fmt == Format.BINARY
+            func = self._adapt_funcs[cls, fmt] = adapter.get_binary_adapter(
+                cls, self.connection
+            )
+
+        return func(obj)
+
+    def lookup_adapter(self, cls):
+        cur = self.cursor
+        if (
+            cur is not None
+            and cls in cur.adapters
+        ):
+            return cur.adapters[cls]
+
+        conn = self.connection
+        if (
+            conn is not None
+            and cls in conn.adapters
+        ):
+            return conn.adapters[cls]
+
+        if cls in global_adapters:
+            return global_adapters[cls]
+
+        raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}")
+
+
+global_adapters = {}
+
+
+class Adapter:
+    def get_text_adapter(self, cls, conn):
+        raise exc.NotSupportedError(
+            f"the type {cls.__name__} doesn't support text adaptation"
+        )
+
+    def get_binary_adapter(self, cls, conn):
+        raise exc.NotSupportedError(
+            f"the type {cls.__name__} doesn't support binary adaptation"
+        )
+
+
+class StringAdapter(Adapter):
+    def get_text_adapter(self, cls, conn):
+        codec = conn.codec if conn is not None else utf8_codec
+
+        def adapt_text(value):
+            return codec.encode(value)[0], TEXT_OID
+
+        return adapt_text
+
+    # format is the same in binary and text
+    get_binary_adapter = get_text_adapter
+
+
+global_adapters[str] = StringAdapter()
+
+
+class IntAdapter(Adapter):
+    def get_text_adapter(self, cls, conn):
+        return self.adapt_int
+
+    def adapt_int(self, value):
+        return ascii_encode(str(value))[0], NUMERIC_OID
+
+
+global_adapters[int] = IntAdapter()
index 86ef56f0bb1c73a2ee9a617598ded38c212e51c0..874e2e44849be5a4027c98a442fc967e95767119 100644 (file)
@@ -29,6 +29,7 @@ class BaseConnection:
     def __init__(self, pgconn):
         self.pgconn = pgconn
         self.cursor_factory = None
+        self.adapters = {}
         # name of the postgres encoding (in bytes)
         self._pgenc = None
 
index 603ffdcf9b63c206a6ec19658694fc22e65feabc..3bbb8a96c22297c7d137a8a7db0250422fa3740b 100644 (file)
@@ -6,6 +6,7 @@ psycopg3 cursor objects
 
 from . import exceptions as exc
 from .pq import error_message, DiagnosticField, ExecStatus
+from .adaptation import ValuesAdapter
 from .utils.queries import query2pg, reorder_params
 
 
@@ -13,6 +14,7 @@ class BaseCursor:
     def __init__(self, conn, binary=False):
         self.conn = conn
         self.binary = binary
+        self.adapters = {}
         self._results = []
         self._result = None
         self._iresult = 0
@@ -33,9 +35,10 @@ class BaseCursor:
         if vars:
             if order is not None:
                 vars = reorder_params(vars, order)
-            params = self._adapt_sequence(vars, formats)
+            adapter = ValuesAdapter(self)
+            params, types = adapter.adapt_sequence(vars, formats)
             self.conn.pgconn.send_query_params(
-                query, params, param_formats=formats
+                query, params, param_formats=formats, param_types=types
             )
         else:
             self.conn.pgconn.send_query(query)
@@ -83,14 +86,6 @@ class BaseCursor:
             self._result = self._results[self._iresult]
             return True
 
-    def _adapt_sequence(self, vars, formats):
-        # TODO: stub. Need adaptation layer.
-        codec = self.conn.codec
-        out = [
-            codec.encode(str(v))[0] if v is not None else None for v in vars
-        ]
-        return out
-
 
 class Cursor(BaseCursor):
     def execute(self, query, vars=None):
index e7579d71b1f809d4411c2271ad7fb97c8af1015d..af8322e3ed8e00eac89ba40bb3300b0e0d598a42 100644 (file)
@@ -38,9 +38,9 @@ def test_exec_params_types(pq, pgconn):
 
 
 def test_exec_params_nulls(pq, pgconn):
-    if pgconn.server_version < 100000:
-        pytest.xfail("it doesn't work on pg < 10")
-    res = pgconn.exec_params(b"select $1, $2, $3", [b"hi", b"", None])
+    res = pgconn.exec_params(
+        b"select $1::text, $2::text, $3::text", [b"hi", b"", None]
+    )
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"hi"
     assert res.get_value(0, 1) == b""
index 4c9e0e587a85c54f5a20e6eb2d3fb84fa64c1714..4c0dfd57caf5d8f01029b0763474879fb6739ed0 100644 (file)
@@ -1,6 +1,3 @@
-import pytest
-
-
 def test_execute_many(aconn, loop):
     cur = aconn.cursor()
     rv = loop.run_until_complete(cur.execute("select 'foo'; select 'bar'"))
@@ -13,8 +10,6 @@ def test_execute_many(aconn, loop):
 
 
 def test_execute_sequence(aconn, loop):
-    if aconn.pgconn.server_version < 100000:
-        pytest.xfail("it doesn't work on pg < 10")
     cur = aconn.cursor()
     rv = loop.run_until_complete(
         cur.execute("select %s, %s, %s", [1, "foo", None])
index 658153ff9530e7af65b099dab6895c4ee0d401d2..52fc3f8172f9124a1016278c055cdf8c2525c084 100644 (file)
@@ -1,6 +1,3 @@
-import pytest
-
-
 def test_execute_many(conn):
     cur = conn.cursor()
     rv = cur.execute("select 'foo'; select 'bar'")
@@ -13,8 +10,6 @@ def test_execute_many(conn):
 
 
 def test_execute_sequence(conn):
-    if conn.pgconn.server_version < 100000:
-        pytest.xfail("it doesn't work on pg < 10")
     cur = conn.cursor()
     rv = cur.execute("select %s, %s, %s", [1, "foo", None])
     assert rv is cur