]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move the row maker as a Transformer attribute
authorDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 11 Feb 2021 11:22:32 +0000 (12:22 +0100)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 11 Feb 2021 16:12:25 +0000 (17:12 +0100)
Instead of carrying the row maker (_make_row attribute) on the cursor
and possibly calling it to transform each row in cursor methods, we
define a 'make_row' attribute on Transformer that is possibly used in
load_row() and load_rows().

In the Python implementation of Transformer.load_rows(), we use tuple
as make_row() when the attribute is unset.

In the Cython implementation, we make 'make_row' a plain property with a
'_make_row' attribute under the hood. We then transform individual
records, or lists of records, using self.make_row().

psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx

index 3663d8ec157924ee34d61554aba5f131f73e39ac..47da6a18aceaf657b15584de92fd130d965d3f08 100644 (file)
@@ -11,7 +11,7 @@ from collections import defaultdict
 from . import pq
 from . import errors as e
 from .oids import INVALID_OID
-from .proto import LoadFunc, AdaptContext
+from .proto import LoadFunc, AdaptContext, Row, RowMaker
 from ._enums import Format
 
 if TYPE_CHECKING:
@@ -38,6 +38,7 @@ class Transformer(AdaptContext):
     __module__ = "psycopg3.adapt"
     _adapters: "AdaptersMap"
     _pgresult: Optional["PGresult"] = None
+    make_row: Optional[RowMaker] = None
 
     def __init__(self, context: Optional[AdaptContext] = None):
 
@@ -157,7 +158,7 @@ class Transformer(AdaptContext):
             dumper = cache[key1] = dumper.upgrade(obj, format)
             return dumper
 
-    def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
+    def load_rows(self, row0: int, row1: int) -> List[Row]:
         res = self._pgresult
         if not res:
             raise e.InterfaceError("result not set")
@@ -167,19 +168,23 @@ class Transformer(AdaptContext):
                 f"rows must be included between 0 and {self._ntuples}"
             )
 
-        records: List[Tuple[Any, ...]]
+        records: List[Row]
         records = [None] * (row1 - row0)  # type: ignore[list-item]
+        if self.make_row:
+            mkrow = self.make_row
+        else:
+            mkrow = tuple
         for row in range(row0, row1):
             record: List[Any] = [None] * self._nfields
             for col in range(self._nfields):
                 val = res.get_value(row, col)
                 if val is not None:
                     record[col] = self._row_loaders[col](val)
-            records[row - row0] = tuple(record)
+            records[row - row0] = mkrow(record)
 
         return records
 
-    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+    def load_row(self, row: int) -> Optional[Row]:
         res = self._pgresult
         if not res:
             return None
@@ -193,7 +198,7 @@ class Transformer(AdaptContext):
             if val is not None:
                 record[col] = self._row_loaders[col](val)
 
-        return tuple(record)
+        return self.make_row(record) if self.make_row else tuple(record)
 
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
index e4cc116f322c71a0013395ed0e30519e86f90bdf..362fc750e53545d62d766c0d878c83d58e239f38 100644 (file)
@@ -18,7 +18,7 @@ from . import generators
 from .pq import ExecStatus, Format
 from .copy import Copy, AsyncCopy
 from .proto import ConnectionType, Query, Params, PQGen
-from .proto import Row, RowFactory, RowMaker
+from .proto import Row, RowFactory
 from ._column import Column
 from ._queries import PostgresQuery
 from ._preparing import Prepare
@@ -50,7 +50,7 @@ class BaseCursor(Generic[ConnectionType]):
     if sys.version_info >= (3, 7):
         __slots__ = """
             _conn format _adapters arraysize _closed _results _pgresult _pos
-            _iresult _rowcount _pgq _tx _last_query _row_factory _make_row
+            _iresult _rowcount _pgq _tx _last_query _row_factory
             __weakref__
             """.split()
 
@@ -76,7 +76,6 @@ class BaseCursor(Generic[ConnectionType]):
     def _reset(self) -> None:
         self._results: List["PGresult"] = []
         self._pgresult: Optional["PGresult"] = None
-        self._make_row: Optional[RowMaker] = None
         self._pos = 0
         self._iresult = 0
         self._rowcount = -1
@@ -266,7 +265,7 @@ class BaseCursor(Generic[ConnectionType]):
 
         elif res.status == ExecStatus.SINGLE_TUPLE:
             if self._row_factory:
-                self._make_row = self._row_factory(self)
+                self._tx.make_row = self._row_factory(self)
             self.pgresult = res  # will set it on the transformer too
             # TODO: the transformer may do excessive work here: create a
             # path that doesn't clear the loaders every time.
@@ -371,7 +370,7 @@ class BaseCursor(Generic[ConnectionType]):
         self._results = list(results)
         self.pgresult = results[0]
         if self._row_factory:
-            self._make_row = self._row_factory(self)
+            self._tx.make_row = self._row_factory(self)
         nrows = self.pgresult.command_tuples
         if nrows is not None:
             if self._rowcount < 0:
@@ -495,7 +494,7 @@ class Cursor(BaseCursor["Connection"]):
             while self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                yield self._make_row(rec) if self._make_row else rec
+                yield rec
 
     def fetchone(self) -> Optional[Row]:
         """
@@ -507,8 +506,7 @@ class Cursor(BaseCursor["Connection"]):
         record = self._tx.load_row(self._pos)
         if record is not None:
             self._pos += 1
-            return self._make_row(record) if self._make_row else record
-        return record
+        return record  # type: ignore[no-any-return]
 
     def fetchmany(self, size: int = 0) -> Sequence[Row]:
         """
@@ -525,8 +523,6 @@ class Cursor(BaseCursor["Connection"]):
             self._pos, min(self._pos + size, self.pgresult.ntuples)
         )
         self._pos += len(records)
-        if self._make_row:
-            return list(map(self._make_row, records))
         return records
 
     def fetchall(self) -> Sequence[Row]:
@@ -537,8 +533,6 @@ class Cursor(BaseCursor["Connection"]):
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
         self._pos += self.pgresult.ntuples
-        if self._make_row:
-            return list(map(self._make_row, records))
         return records
 
     def __iter__(self) -> Iterator[Row]:
@@ -551,7 +545,7 @@ class Cursor(BaseCursor["Connection"]):
             if row is None:
                 break
             self._pos += 1
-            yield self._make_row(row) if self._make_row else row
+            yield row
 
     @contextmanager
     def copy(self, statement: Query) -> Iterator[Copy]:
@@ -610,15 +604,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             while await self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                yield self._make_row(rec) if self._make_row else rec
+                yield rec
 
     async def fetchone(self) -> Optional[Row]:
         self._check_result()
         rv = self._tx.load_row(self._pos)
         if rv is not None:
             self._pos += 1
-            return self._make_row(rv) if self._make_row else rv
-        return rv
+        return rv  # type: ignore[no-any-return]
 
     async def fetchmany(self, size: int = 0) -> List[Row]:
         self._check_result()
@@ -630,8 +623,6 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             self._pos, min(self._pos + size, self.pgresult.ntuples)
         )
         self._pos += len(records)
-        if self._make_row:
-            return list(map(self._make_row, records))
         return records
 
     async def fetchall(self) -> List[Row]:
@@ -639,8 +630,6 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
         self._pos += self.pgresult.ntuples
-        if self._make_row:
-            return list(map(self._make_row, records))
         return records
 
     async def __aiter__(self) -> AsyncIterator[Row]:
@@ -653,7 +642,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             if row is None:
                 break
             self._pos += 1
-            yield self._make_row(row) if self._make_row else row
+            yield row
 
     @asynccontextmanager
     async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
index 517c782ad2e84ef2549c441bc38789ad0360ac63..0581c02393f118a4ca70ed328a7acdc8962fe623 100644 (file)
@@ -45,6 +45,21 @@ Wait states.
 """
 
 
+# Row factories
+
+Row = TypeVar("Row", Tuple[Any, ...], Any)
+
+
+class RowMaker(Protocol):
+    def __call__(self, __values: Sequence[Any]) -> Row:
+        ...
+
+
+class RowFactory(Protocol):
+    def __call__(self, __cursor: "BaseCursor[ConnectionType]") -> RowMaker:
+        ...
+
+
 # Adaptation types
 
 DumpFunc = Callable[[Any], bytes]
@@ -71,6 +86,8 @@ class AdaptContext(Protocol):
 
 
 class Transformer(Protocol):
+    make_row: Optional[RowMaker] = None
+
     def __init__(self, context: Optional[AdaptContext] = None):
         ...
 
@@ -103,10 +120,10 @@ class Transformer(Protocol):
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
         ...
 
-    def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
+    def load_rows(self, row0: int, row1: int) -> List[Row]:
         ...
 
-    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+    def load_row(self, row: int) -> Optional[Row]:
         ...
 
     def load_sequence(
@@ -116,18 +133,3 @@ class Transformer(Protocol):
 
     def get_loader(self, oid: int, format: pq.Format) -> "Loader":
         ...
-
-
-# Row factories
-
-Row = TypeVar("Row", Tuple[Any, ...], Any)
-
-
-class RowMaker(Protocol):
-    def __call__(self, __values: Sequence[Any]) -> Row:
-        ...
-
-
-class RowFactory(Protocol):
-    def __call__(self, __cursor: "BaseCursor[ConnectionType]") -> RowMaker:
-        ...
index 6f6c419385e03d845900769c791859158c28e376..fc603dd3e390de68249706e13d5e276b3c2d6d49 100644 (file)
@@ -22,6 +22,10 @@ class Transformer(proto.AdaptContext):
     @property
     def adapters(self) -> AdaptersMap: ...
     @property
+    def make_row(self) -> Optional[proto.RowMaker]: ...
+    @make_row.setter
+    def make_row(self, row_maker: proto.RowMaker) -> None: ...
+    @property
     def pgresult(self) -> Optional[PGresult]: ...
     @pgresult.setter
     def pgresult(self, result: Optional[PGresult]) -> None: ...
@@ -32,8 +36,8 @@ class Transformer(proto.AdaptContext):
         self, params: Sequence[Any], formats: Sequence[Format]
     ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
     def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
-    def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]: ...
-    def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
+    def load_rows(self, row0: int, row1: int) -> List[proto.Row]: ...
+    def load_row(self, row: int) -> Optional[proto.Row]: ...
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
     ) -> Tuple[Any, ...]: ...
index f92c94d50649c620d7200d01c2242c0beb944096..86ac903480c9d8bd557b950162dcea41efd4a93d 100644 (file)
@@ -24,6 +24,7 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
 from psycopg3 import errors as e
 from psycopg3._enums import Format as Pg3Format
 from psycopg3.pq import Format as PqFormat
+from psycopg3.proto import Row, RowMaker
 
 # internal structure: you are not supposed to know this. But it's worth some
 # 10% of the innermost loop, so I'm willing to ask for forgiveness later...
@@ -82,6 +83,7 @@ cdef class Transformer:
     cdef int _nfields, _ntuples
     cdef list _row_dumpers
     cdef list _row_loaders
+    cdef object _make_row
 
     def __cinit__(self, context: Optional["AdaptContext"] = None):
         if context is not None:
@@ -92,6 +94,14 @@ cdef class Transformer:
             self.adapters = global_adapters
             self.connection = None
 
+    @property
+    def make_row(self) -> Optional[RowMaker]:
+        return self._make_row
+
+    @make_row.setter
+    def make_row(self, row_maker: RowMaker) -> None:
+        self._make_row = row_maker
+
     @property
     def pgresult(self) -> Optional[PGresult]:
         return self._pgresult
@@ -271,7 +281,7 @@ cdef class Transformer:
 
         return ps, ts, fs
 
-    def load_rows(self, int row0, int row1) -> List[Tuple[Any, ...]]:
+    def load_rows(self, int row0, int row1) -> List[Row]:
         if self._pgresult is None:
             raise e.InterfaceError("result not set")
 
@@ -331,9 +341,11 @@ cdef class Transformer:
                     Py_INCREF(pyval)
                     PyTuple_SET_ITEM(<object>brecord, col, pyval)
 
+        if self.make_row:
+            return list(map(self.make_row, records))
         return records
 
-    def load_row(self, int row) -> Optional[Tuple[Any, ...]]:
+    def load_row(self, int row) -> Optional[Row]:
         if self._pgresult is None:
             return None
 
@@ -372,6 +384,8 @@ cdef class Transformer:
             Py_INCREF(pyval)
             PyTuple_SET_ITEM(record, col, pyval)
 
+        if self.make_row:
+            return self.make_row(record)
         return record
 
     cpdef object load_sequence(self, record: Sequence[Optional[bytes]]):