git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move 'make_row' attribute from Transformer to Cursor
author: Denis Laxalde <denis.laxalde@dalibo.com>
Wed, 14 Apr 2021 14:02:53 +0000 (16:02 +0200)
committer: Denis Laxalde <denis.laxalde@dalibo.com>
Fri, 23 Apr 2021 06:50:16 +0000 (08:50 +0200)
Having this attribute defined in the Transformer protocol would be
problematic when making the RowMaker protocol generic on Row: we would
then have to also make Transformer generic on Row, and propagating that
type variable would produce a lot of churn in the code base.

Also, the variance (covariant) of Row in RowMaker will conflict with that
in Transformer (invariant).

On the other hand, keeping the RowMaker value attached to the same
object as its underlying RowFactory (now the Cursor) seems safer for
consistency.

psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/server_cursor.py
psycopg3_c/psycopg3_c/_psycopg3.pyi
psycopg3_c/psycopg3_c/_psycopg3/transform.pyx
tests/test_rows.py

index 10ce79b72cc0d93ac065cfb19fde2a06bb616fc0..6f15f6d557d6f485c84580ddd723efa25a635076 100644 (file)
@@ -38,7 +38,6 @@ class Transformer(AdaptContext):
     __module__ = "psycopg3.adapt"
     _adapters: "AdaptersMap"
     _pgresult: Optional["PGresult"] = None
-    make_row: RowMaker = tuple
 
     def __init__(self, context: Optional[AdaptContext] = None):
 
@@ -162,7 +161,7 @@ class Transformer(AdaptContext):
             dumper = cache[key1] = dumper.upgrade(obj, format)
             return dumper
 
-    def load_rows(self, row0: int, row1: int) -> List[Row]:
+    def load_rows(self, row0: int, row1: int, make_row: RowMaker) -> List[Row]:
         res = self._pgresult
         if not res:
             raise e.InterfaceError("result not set")
@@ -179,11 +178,11 @@ class Transformer(AdaptContext):
                 val = res.get_value(row, col)
                 if val is not None:
                     record[col] = self._row_loaders[col](val)
-            records.append(self.make_row(record))
+            records.append(make_row(record))
 
         return records
 
-    def load_row(self, row: int) -> Optional[Row]:
+    def load_row(self, row: int, make_row: RowMaker) -> Optional[Row]:
         res = self._pgresult
         if not res:
             return None
@@ -197,7 +196,7 @@ class Transformer(AdaptContext):
             if val is not None:
                 record[col] = self._row_loaders[col](val)
 
-        return self.make_row(record)  # type: ignore[no-any-return]
+        return make_row(record)  # type: ignore[no-any-return]
 
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
index cb157f9df2e4d76c72b3f65d5959b3db0641d0ff..e6e981b06d7b4aa018b6bb6559ac6a4cc8a9b8c6 100644 (file)
@@ -48,7 +48,7 @@ class BaseCursor(Generic[ConnectionType]):
     if sys.version_info >= (3, 7):
         __slots__ = """
             _conn format _adapters arraysize _closed _results pgresult _pos
-            _iresult _rowcount _pgq _tx _last_query _row_factory
+            _iresult _rowcount _pgq _tx _last_query _row_factory _make_row
             __weakref__
             """.split()
 
@@ -165,7 +165,7 @@ class BaseCursor(Generic[ConnectionType]):
         if self._iresult < len(self._results):
             self.pgresult = self._results[self._iresult]
             self._tx.set_pgresult(self._results[self._iresult])
-            self._tx.make_row = self._row_factory(self)
+            self._make_row = self._row_factory(self)
             self._pos = 0
             nrows = self.pgresult.command_tuples
             self._rowcount = nrows if nrows is not None else -1
@@ -182,7 +182,7 @@ class BaseCursor(Generic[ConnectionType]):
     def row_factory(self, row_factory: RowFactory) -> None:
         self._row_factory = row_factory
         if self.pgresult:
-            self._tx.make_row = row_factory(self)
+            self._make_row = row_factory(self)
 
     #
     # Generators for the high level operations on the cursor
@@ -279,7 +279,7 @@ class BaseCursor(Generic[ConnectionType]):
             self.pgresult = res
             self._tx.set_pgresult(res, set_loaders=first)
             if first:
-                self._tx.make_row = self._row_factory(self)
+                self._make_row = self._row_factory(self)
             return res
 
         elif res.status in (ExecStatus.TUPLES_OK, ExecStatus.COMMAND_OK):
@@ -382,7 +382,7 @@ class BaseCursor(Generic[ConnectionType]):
         self._results = list(results)
         self.pgresult = results[0]
         self._tx.set_pgresult(results[0])
-        self._tx.make_row = self._row_factory(self)
+        self._make_row = self._row_factory(self)
         nrows = self.pgresult.command_tuples
         if nrows is not None:
             if self._rowcount < 0:
@@ -529,7 +529,7 @@ class Cursor(BaseCursor["Connection"]):
             self._conn.wait(self._stream_send_gen(query, params))
             first = True
             while self._conn.wait(self._stream_fetchone_gen(first)):
-                rec = self._tx.load_row(0)
+                rec = self._tx.load_row(0, self._make_row)
                 assert rec is not None
                 yield rec
                 first = False
@@ -543,7 +543,7 @@ class Cursor(BaseCursor["Connection"]):
         :rtype: Optional[Row], with Row defined by `row_factory`
         """
         self._check_result()
-        record = self._tx.load_row(self._pos)
+        record = self._tx.load_row(self._pos, self._make_row)
         if record is not None:
             self._pos += 1
         return record
@@ -562,7 +562,9 @@ class Cursor(BaseCursor["Connection"]):
         if not size:
             size = self.arraysize
         records: List[Row] = self._tx.load_rows(
-            self._pos, min(self._pos + size, self.pgresult.ntuples)
+            self._pos,
+            min(self._pos + size, self.pgresult.ntuples),
+            self._make_row,
         )
         self._pos += len(records)
         return records
@@ -576,7 +578,7 @@ class Cursor(BaseCursor["Connection"]):
         self._check_result()
         assert self.pgresult
         records: List[Row] = self._tx.load_rows(
-            self._pos, self.pgresult.ntuples
+            self._pos, self.pgresult.ntuples, self._make_row
         )
         self._pos = self.pgresult.ntuples
         return records
@@ -584,7 +586,8 @@ class Cursor(BaseCursor["Connection"]):
     def __iter__(self) -> Iterator[Row]:
         self._check_result()
 
-        load = self._tx.load_row
+        def load(pos: int) -> Optional[Row]:
+            return self._tx.load_row(pos, self._make_row)
 
         while 1:
             row = load(self._pos)
@@ -667,14 +670,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             await self._conn.wait(self._stream_send_gen(query, params))
             first = True
             while await self._conn.wait(self._stream_fetchone_gen(first)):
-                rec = self._tx.load_row(0)
+                rec = self._tx.load_row(0, self._make_row)
                 assert rec is not None
                 yield rec
                 first = False
 
     async def fetchone(self) -> Optional[Row]:
         self._check_result()
-        rv = self._tx.load_row(self._pos)
+        rv = self._tx.load_row(self._pos, self._make_row)
         if rv is not None:
             self._pos += 1
         return rv
@@ -686,7 +689,9 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         if not size:
             size = self.arraysize
         records: List[Row] = self._tx.load_rows(
-            self._pos, min(self._pos + size, self.pgresult.ntuples)
+            self._pos,
+            min(self._pos + size, self.pgresult.ntuples),
+            self._make_row,
         )
         self._pos += len(records)
         return records
@@ -695,7 +700,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         self._check_result()
         assert self.pgresult
         records: List[Row] = self._tx.load_rows(
-            self._pos, self.pgresult.ntuples
+            self._pos, self.pgresult.ntuples, self._make_row
         )
         self._pos = self.pgresult.ntuples
         return records
@@ -703,7 +708,8 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
     async def __aiter__(self) -> AsyncIterator[Row]:
         self._check_result()
 
-        load = self._tx.load_row
+        def load(pos: int) -> Optional[Row]:
+            return self._tx.load_row(pos, self._make_row)
 
         while 1:
             row = load(self._pos)
index b97858780a96758af4dd7e409f1662388aaed87e..f06f9cd7690f3ea5d37c3e319fa31e10db890e59 100644 (file)
@@ -89,8 +89,6 @@ class Transformer(Protocol):
     def __init__(self, context: Optional[AdaptContext] = None):
         ...
 
-    make_row: RowMaker
-
     @property
     def connection(self) -> Optional["BaseConnection"]:
         ...
@@ -121,10 +119,10 @@ class Transformer(Protocol):
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
         ...
 
-    def load_rows(self, row0: int, row1: int) -> List[Row]:
+    def load_rows(self, row0: int, row1: int, make_row: RowMaker) -> List[Row]:
         ...
 
-    def load_row(self, row: int) -> Optional[Row]:
+    def load_row(self, row: int, make_row: RowMaker) -> Optional[Row]:
         ...
 
     def load_sequence(
index d15ccfcb3bad0706e9224bfe3d72955d43d73124..4aa30f77d485aa4516928daaa40e06f2a908b3b4 100644 (file)
@@ -120,7 +120,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
 
         cur.pgresult = res
         cur._tx.set_pgresult(res, set_loaders=False)
-        return cur._tx.load_rows(0, res.ntuples)
+        return cur._tx.load_rows(0, res.ntuples, cur._make_row)
 
     def _scroll_gen(
         self, cur: BaseCursor[ConnectionType], value: int, mode: str
index bd111aea115e5a9d05cc1e74dca7b63fd4ea0f01..caf380fe9d5d3cbd59446e1f76a64d19d50c281a 100644 (file)
@@ -17,7 +17,6 @@ from psycopg3.pq.proto import PGconn, PGresult
 
 class Transformer(proto.AdaptContext):
     def __init__(self, context: Optional[proto.AdaptContext] = None): ...
-    make_row: proto.RowMaker
     @property
     def connection(self) -> Optional[BaseConnection]: ...
     @property
@@ -34,8 +33,12 @@ class Transformer(proto.AdaptContext):
         self, params: Sequence[Any], formats: Sequence[Format]
     ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
     def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
-    def load_rows(self, row0: int, row1: int) -> List[proto.Row]: ...
-    def load_row(self, row: int) -> Optional[proto.Row]: ...
+    def load_rows(
+        self, row0: int, row1: int, make_row: proto.RowMaker
+    ) -> List[proto.Row]: ...
+    def load_row(
+        self, row: int, make_row: proto.RowMaker
+    ) -> Optional[proto.Row]: ...
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
     ) -> Tuple[Any, ...]: ...
index 5635d6ca514871f573039b102eb7e3e1fbaeb314..d52d60e5de6b0e620f14877c72f5fa119b975bb8 100644 (file)
@@ -83,7 +83,6 @@ cdef class Transformer:
     cdef int _nfields, _ntuples
     cdef list _row_dumpers
     cdef list _row_loaders
-    cdef public object make_row
 
     def __cinit__(self, context: Optional["AdaptContext"] = None):
         if context is not None:
@@ -274,7 +273,7 @@ cdef class Transformer:
 
         return ps, ts, fs
 
-    def load_rows(self, int row0, int row1) -> List[Row]:
+    def load_rows(self, int row0, int row1, object make_row) -> List[Row]:
         if self._pgresult is None:
             raise e.InterfaceError("result not set")
 
@@ -334,7 +333,6 @@ cdef class Transformer:
                     Py_INCREF(pyval)
                     PyTuple_SET_ITEM(<object>brecord, col, pyval)
 
-        cdef object make_row = self.make_row
         if make_row is not tuple:
             for i in range(row1 - row0):
                 brecord = PyList_GET_ITEM(records, i)
@@ -345,7 +343,7 @@ cdef class Transformer:
                 Py_DECREF(<object>brecord)
         return records
 
-    def load_row(self, int row) -> Optional[Row]:
+    def load_row(self, int row, object make_row) -> Optional[Row]:
         if self._pgresult is None:
             return None
 
@@ -384,7 +382,6 @@ cdef class Transformer:
             Py_INCREF(pyval)
             PyTuple_SET_ITEM(record, col, pyval)
 
-        cdef object make_row = self.make_row
         if make_row is not tuple:
             record = PyObject_CallFunctionObjArgs(
                 make_row, <PyObject *>record, NULL)
index 836db33d8d98a188b8b3c2e778a9b777a763ceb9..166344ac65aa40f70e8245dc9490608f896f33a6 100644 (file)
@@ -8,7 +8,7 @@ def test_tuple_row(conn):
     row = cur.execute("select 1 as a").fetchone()
     assert row == (1,)
     assert type(row) is tuple
-    assert cur._tx.make_row is tuple
+    assert cur._make_row is tuple
 
 
 def test_dict_row(conn):