]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added result to transformer, fetching data from there
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 18 Apr 2020 06:22:54 +0000 (18:22 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 18 Apr 2020 06:23:42 +0000 (18:23 +1200)
psycopg3/adapt.py
psycopg3/cursor.py

index eadfcd1b4e55ebd0578207ecfe18fd89f15aea9b..73ecba12165373efdbfd1298b12a2b504d19fac6 100644 (file)
@@ -173,6 +173,7 @@ class Transformer:
         self._dumpers_maps: List[DumpersMap] = []
         self._loaders_maps: List[LoadersMap] = []
         self._setup_context(context)
+        self.pgresult = None
 
         # mapping class, fmt -> dump function
         self._dump_funcs: Dict[Tuple[type, Format], DumpFunc] = {}
@@ -228,6 +229,34 @@ class Transformer:
         self._dumpers_maps.append(Dumper.globals)
         self._loaders_maps.append(Loader.globals)
 
+    @property
+    def pgresult(self) -> Optional[pq.PGresult]:
+        return self._pgresult
+
+    @pgresult.setter
+    def pgresult(self, result: Optional[pq.PGresult]) -> None:
+        self._pgresult = result
+        rc = self._row_loaders = []
+
+        self._ntuples: int
+        self._nfields: int
+        if result is None:
+            self._nfields = self._ntuples = 0
+            return
+
+        nf = self._nfields = result.nfields
+        self._ntuples = result.ntuples
+
+        for i in range(nf):
+            oid = result.ftype(i)
+            fmt = result.fformat(i)
+            rc.append(self.get_load_function(oid, fmt))
+
+    def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
+        rc = self._row_loaders = []
+        for oid, fmt in types:
+            rc.append(self.get_load_function(oid, fmt))
+
     def dump_sequence(
         self, objs: Iterable[Any], formats: Iterable[Format]
     ) -> Tuple[List[Optional[bytes]], List[int]]:
@@ -282,10 +311,23 @@ class Transformer:
             f"cannot adapt type {src.__name__} to format {Format(format).name}"
         )
 
-    def set_row_types(self, types: Iterable[Tuple[int, Format]]) -> None:
-        rc = self._row_loaders = []
-        for oid, fmt in types:
-            rc.append(self.get_load_function(oid, fmt))
+    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+        res = self.pgresult
+        if res is None:
+            return None
+
+        if row >= self._ntuples:
+            return None
+
+        rv: List[Any] = []
+        for col in range(self._nfields):
+            val = res.get_value(row, col)
+            if val is None:
+                rv.append(None)
+            else:
+                rv.append(self._row_loaders[col](val))
+
+        return tuple(rv)
 
     def load_sequence(
         self, record: Iterable[Optional[bytes]]
index e0f7e50403761c92cd2ad8f793376ece9f894d00..e07ec0d2c1b37596ac4af755626ece96c8a34c95 100644 (file)
@@ -6,7 +6,7 @@ psycopg3 cursor objects
 
 import codecs
 from operator import attrgetter
-from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING
+from typing import Any, List, Optional, Sequence, TYPE_CHECKING
 
 from . import errors as e
 from . import pq
@@ -97,13 +97,8 @@ class BaseCursor:
     def pgresult(self, result: Optional[pq.PGresult]) -> None:
         self._pgresult = result
         if result is not None:
-            self._ntuples = result.ntuples
-            self._nfields = result.nfields
             if self._transformer is not None:
-                self._transformer.set_row_types(
-                    (result.ftype(i), result.fformat(i))
-                    for i in range(self._nfields)
-                )
+                self._transformer.pgresult = result
 
     @property
     def description(self) -> Optional[List[Column]]:
@@ -243,15 +238,6 @@ class BaseCursor:
                 "the last operation didn't produce a result"
             )
 
-    def _load_row(self, n: int) -> Optional[Tuple[Any, ...]]:
-        if n >= self._ntuples:
-            return None
-
-        get_value = self.pgresult.get_value  # type: ignore
-        return self._transformer.load_sequence(
-            get_value(n, i) for i in range(self._nfields)
-        )
-
 
 class Cursor(BaseCursor):
     connection: "Connection"
@@ -296,7 +282,7 @@ class Cursor(BaseCursor):
 
     def fetchone(self) -> Optional[Sequence[Any]]:
         self._check_result()
-        rv = self._load_row(self._pos)
+        rv = self._transformer.load_row(self._pos)
         if rv is not None:
             self._pos += 1
         return rv
@@ -308,7 +294,7 @@ class Cursor(BaseCursor):
 
         rv: List[Sequence[Any]] = []
         while len(rv) < size:
-            row = self._load_row(self._pos)
+            row = self._transformer.load_row(self._pos)
             if row is None:
                 break
             self._pos += 1
@@ -320,7 +306,7 @@ class Cursor(BaseCursor):
         self._check_result()
         rv: List[Sequence[Any]] = []
         while 1:
-            row = self._load_row(self._pos)
+            row = self._transformer.load_row(self._pos)
             if row is None:
                 break
             self._pos += 1
@@ -374,7 +360,7 @@ class AsyncCursor(BaseCursor):
 
     async def fetchone(self) -> Optional[Sequence[Any]]:
         self._check_result()
-        rv = self._load_row(self._pos)
+        rv = self._transformer.load_row(self._pos)
         if rv is not None:
             self._pos += 1
         return rv
@@ -388,7 +374,7 @@ class AsyncCursor(BaseCursor):
 
         rv: List[Sequence[Any]] = []
         while len(rv) < size:
-            row = self._load_row(self._pos)
+            row = self._transformer.load_row(self._pos)
             if row is None:
                 break
             self._pos += 1
@@ -400,7 +386,7 @@ class AsyncCursor(BaseCursor):
         self._check_result()
         rv: List[Sequence[Any]] = []
         while 1:
-            row = self._load_row(self._pos)
+            row = self._transformer.load_row(self._pos)
             if row is None:
                 break
             self._pos += 1