]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added basic typecasting and cursor.fetchone()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Mar 2020 12:26:57 +0000 (01:26 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Mar 2020 12:33:51 +0000 (01:33 +1300)
psycopg3/adaptation.py
psycopg3/connection.py
psycopg3/cursor.py
tests/test_cursor.py

index afc33c7d8a1fdedc6263bc19ff3f23b15f484a02..978bb98baec1048842b80cfdc987912dfb7a95a7 100644 (file)
@@ -5,6 +5,7 @@ Entry point into the adaptation system.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
+from functools import partial
 
 from . import exceptions as exc
 from .pq import Format
@@ -15,10 +16,14 @@ NUMERIC_OID = 1700
 FLOAT8_INT = 701
 
 ascii_encode = codecs.lookup("ascii").encode
+ascii_decode = codecs.lookup("ascii").decode
 utf8_codec = codecs.lookup("utf-8")
 
+global_adapters = {}
+global_casters = {}
+
 
-class ValuesAdapter:
+class ValuesTransformer:
     """
     An object that can adapt efficiently a number of value.
 
@@ -43,11 +48,38 @@ class ValuesAdapter:
         else:
             raise TypeError(
                 f"the context should be a connection or cursor,"
-                f" got {type(context).__name__}")
+                f" got {type(context).__name__}"
+            )
 
-        # mapping class -> adaptation function
+        # mapping class, fmt -> adaptation function
         self._adapt_funcs = {}
 
+        # mapping oid, fmt -> cast function
+        self._cast_funcs = {}
+
+        # The result to return values from
+        self._result = None
+
+        # sequence of cast function from value to python
+        # the length of the result columns
+        self._row_casters = None
+
+    @property
+    def result(self):
+        return self._result
+
+    @result.setter
+    def result(self, result):
+        if self._result is result:
+            return
+
+        rc = self._row_casters = []
+        for c in range(result.nfields):
+            oid = result.ftype(c)
+            fmt = result.fformat(c)
+            func = self.get_cast_function(oid, fmt)
+            rc.append(func)
+
     def adapt_sequence(self, objs, fmts):
         out = []
         types = []
@@ -64,39 +96,35 @@ class ValuesAdapter:
             return None, TEXT_OID
 
         cls = type(obj)
+        func = self.get_adapt_function(cls, fmt)
+        return func(obj)
+
+    def get_adapt_function(self, cls, fmt):
         try:
-            func = self._adapt_funcs[cls, fmt]
+            return self._adapt_funcs[cls, fmt]
         except KeyError:
             pass
-        else:
-            return func(obj)
 
-        adapter = self.lookup_adapter(cls)
+        xf = self.lookup_adapter(cls)
         if fmt == Format.TEXT:
-            func = self._adapt_funcs[cls, fmt] = adapter.get_text_adapter(
+            func = self._adapt_funcs[cls, fmt] = xf.get_text_adapter(
                 cls, self.connection
             )
         else:
             assert fmt == Format.BINARY
-            func = self._adapt_funcs[cls, fmt] = adapter.get_binary_adapter(
+            func = self._adapt_funcs[cls, fmt] = xf.get_binary_adapter(
                 cls, self.connection
             )
 
-        return func(obj)
+        return func
 
     def lookup_adapter(self, cls):
         cur = self.cursor
-        if (
-            cur is not None
-            and cls in cur.adapters
-        ):
+        if cur is not None and cls in cur.adapters:
             return cur.adapters[cls]
 
         conn = self.connection
-        if (
-            conn is not None
-            and cls in conn.adapters
-        ):
+        if conn is not None and cls in conn.adapters:
             return conn.adapters[cls]
 
         if cls in global_adapters:
@@ -104,11 +132,50 @@ class ValuesAdapter:
 
         raise exc.ProgrammingError(f"cannot adapt type {cls.__name__}")
 
+    def cast_row(self, result, n):
+        self.result = result
 
-global_adapters = {}
+        for col, func in enumerate(self._row_casters):
+            v = result.get_value(n, col)
+            if v is not None:
+                v = func(v)
+            yield v
+
+    def get_cast_function(self, oid, fmt):
+        try:
+            return self._cast_funcs[oid, fmt]
+        except KeyError:
+            pass
+
+        xf = self.lookup_caster(oid)
+        if fmt == Format.TEXT:
+            func = self._cast_funcs[oid, fmt] = xf.get_text_caster(
+                oid, self.connection
+            )
+        else:
+            assert fmt == Format.BINARY
+            func = self._cast_funcs[oid, fmt] = xf.get_binary_caster(
+                oid, self.connection
+            )
+
+        return func
 
+    def lookup_caster(self, oid):
+        cur = self.cursor
+        if cur is not None and oid in cur.casters:
+            return cur.casters[oid]
+
+        conn = self.connection
+        if conn is not None and oid in conn.casters:
+            return conn.casters[oid]
 
-class Adapter:
+        if oid in global_casters:
+            return global_casters[oid]
+        else:
+            return UnknownCaster()
+
+
+class Transformer:
     def get_text_adapter(self, cls, conn):
         raise exc.NotSupportedError(
             f"the type {cls.__name__} doesn't support text adaptation"
@@ -119,29 +186,85 @@ class Adapter:
             f"the type {cls.__name__} doesn't support binary adaptation"
         )
 
+    def get_text_caster(self, oid, conn):
+        raise exc.NotSupportedError(
+            f"the PostgreSQL type {oid} doesn't support cast from text"
+        )
+
+    def get_binary_caster(self, oid, conn):
+        raise exc.NotSupportedError(
+            f"the PostgreSQL type {oid} doesn't support cast from binary"
+        )
+
+    @staticmethod
+    def cast_to_bytes(value):
+        return value
+
+    @staticmethod
+    def cast_to_str(codec, value):
+        return codec.decode(value)[0]
 
-class StringAdapter(Adapter):
+
+class StringTransformer(Transformer):
     def get_text_adapter(self, cls, conn):
         codec = conn.codec if conn is not None else utf8_codec
+        return partial(self.adapt_str, codec)
 
-        def adapt_text(value):
-            return codec.encode(value)[0], TEXT_OID
+    def get_text_caster(self, oid, conn):
+        if conn is None or conn.pgenc == b"SQL_ASCII":
+            # we don't have enough info to decode bytes
+            return self.unparsed_bytes
 
-        return adapt_text
+        codec = conn.codec
+        return partial(self.cast_to_str, codec)
 
     # format is the same in binary and text
     get_binary_adapter = get_text_adapter
+    get_binary_caster = get_text_caster
+
+    @staticmethod
+    def adapt_str(codec, value):
+        return codec.encode(value)[0], TEXT_OID
 
 
-global_adapters[str] = StringAdapter()
+global_adapters[str] = global_casters[TEXT_OID] = StringTransformer()
 
 
-class IntAdapter(Adapter):
+class IntTransformer(Transformer):
     def get_text_adapter(self, cls, conn):
         return self.adapt_int
 
-    def adapt_int(self, value):
+    def get_text_caster(self, oid, conn):
+        return self.cast_int
+
+    @staticmethod
+    def adapt_int(value):
         return ascii_encode(str(value))[0], NUMERIC_OID
 
+    @staticmethod
+    def cast_int(value):
+        return int(ascii_decode(value)[0])
+
+
+global_casters[NUMERIC_OID] = global_adapters[int] = IntTransformer()
+
+
+class UnknownCaster(Transformer):
+    """
+    Fallback object to convert unknown types to Python
+    """
+
+    def get_text_caster(self, oid, conn):
+        if conn is None:
+            # we don't have enough info to decode bytes
+            return self.cast_to_bytes
+
+        codec = conn.codec
+        return partial(self.cast_to_str, codec)
+
+    def get_binary_caster(self, oid, conn):
+        return self.cast_to_bytes
 
-global_adapters[int] = IntAdapter()
+    @staticmethod
+    def cast_to_str(codec, value):
+        return codec.decode(value)[0]
index 874e2e44849be5a4027c98a442fc967e95767119..eff6503e4f17985c230e553b20125f76ae84138f 100644 (file)
@@ -30,8 +30,9 @@ class BaseConnection:
         self.pgconn = pgconn
         self.cursor_factory = None
         self.adapters = {}
+        self.casters = {}
         # name of the postgres encoding (in bytes)
-        self._pgenc = None
+        self.pgenc = None
 
     def cursor(self, name=None):
         return self.cursor_factory(self)
@@ -40,10 +41,11 @@ class BaseConnection:
     def codec(self):
         # TODO: utf8 fastpath?
         pgenc = self.pgconn.parameter_status(b"client_encoding")
-        if self._pgenc != pgenc:
+        if self.pgenc != pgenc:
             # for unknown encodings and SQL_ASCII be strict and use ascii
             pyenc = pq.py_codecs.get(pgenc.decode("ascii"), "ascii")
             self._codec = codecs.lookup(pyenc)
+            self.pgenc = pgenc
         return self._codec
 
     def encode(self, s):
index 3bbb8a96c22297c7d137a8a7db0250422fa3740b..030948d74b6d33fa6f8451774bd367ce55abd966 100644 (file)
@@ -6,7 +6,7 @@ psycopg3 cursor objects
 
 from . import exceptions as exc
 from .pq import error_message, DiagnosticField, ExecStatus
-from .adaptation import ValuesAdapter
+from .adaptation import ValuesTransformer
 from .utils.queries import query2pg, reorder_params
 
 
@@ -15,15 +15,20 @@ class BaseCursor:
         self.conn = conn
         self.binary = binary
         self.adapters = {}
+        self.casters = {}
+        self._reset()
+
+    def _reset(self):
         self._results = []
         self._result = None
+        self._pos = 0
         self._iresult = 0
+        self._transformer = ValuesTransformer(self)
 
     def _execute_send(self, query, vars):
         # Implement part of execute() before waiting common to sync and async
-        self._results = []
-        self._result = None
-        self._iresult = 0
+        self._reset()
+
         codec = self.conn.codec
 
         if isinstance(query, str):
@@ -35,8 +40,7 @@ class BaseCursor:
         if vars:
             if order is not None:
                 vars = reorder_params(vars, order)
-            adapter = ValuesAdapter(self)
-            params, types = adapter.adapt_sequence(vars, formats)
+            params, types = self._transformer.adapt_sequence(vars, formats)
             self.conn.pgconn.send_query_params(
                 query, params, param_formats=formats, param_types=types
             )
@@ -84,8 +88,23 @@ class BaseCursor:
         self._iresult += 1
         if self._iresult < len(self._results):
             self._result = self._results[self._iresult]
+            self._pos = 0
             return True
 
+    def fetchone(self):
+        rv = self._cast_row(self._pos)
+        if rv is not None:
+            self._pos += 1
+        return rv
+
+    def _cast_row(self, n):
+        if self._result is None:
+            return None
+        if n >= self._result.ntuples:
+            return None
+
+        return tuple(self._transformer.cast_row(self._result, n))
+
 
 class Cursor(BaseCursor):
     def execute(self, query, vars=None):
index 52fc3f8172f9124a1016278c055cdf8c2525c084..f99c6868ed15acad66df95332e04290a95780480 100644 (file)
@@ -18,3 +18,14 @@ def test_execute_sequence(conn):
     assert cur._result.get_value(0, 1) == b"foo"
     assert cur._result.get_value(0, 2) is None
     assert cur.nextset() is None
+
+
+def test_fetchone(conn):
+    cur = conn.cursor()
+    cur.execute("select %s, %s, %s", [1, "foo", None])
+    row = cur.fetchone()
+    assert row[0] == 1
+    assert row[1] == "foo"
+    assert row[2] is None
+    row = cur.fetchone()
+    assert row is None