]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make row factory optional
authorDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 10 Feb 2021 16:58:42 +0000 (17:58 +0100)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 10 Feb 2021 17:17:02 +0000 (18:17 +0100)
We change the default value of row_factory argument in
connection.cursor() to None and thus use a keyword argument.

On cursor side, we only set the '_make_row' attribute if a 'row_factory'
got passed and we guard all possible calls to _make_row() by an
'if self._make_row' to avoid a Python call per row. Note that, on the
other hand, we now need to cast 'row' values to the 'Row' type in order
to satisfy type checking.

The default_row_factory() is now useless and thus dropped.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py

index 79842ae4a59f5e3081a8d98a49c8b35824cad95e..0caca8bca249f756da320e516527bcb34f1e1abb 100644 (file)
@@ -452,7 +452,7 @@ class Connection(BaseConnection):
         self,
         name: str = "",
         binary: bool = False,
-        row_factory: RowFactory = cursor.default_row_factory,
+        row_factory: Optional[RowFactory] = None,
     ) -> "Cursor":
         """
         Return a new `Cursor` to send commands and queries to the connection.
@@ -461,7 +461,9 @@ class Connection(BaseConnection):
             raise NotImplementedError
 
         format = Format.BINARY if binary else Format.TEXT
-        return self.cursor_factory(self, row_factory, format=format)
+        return self.cursor_factory(
+            self, format=format, row_factory=row_factory
+        )
 
     def execute(
         self,
@@ -592,7 +594,7 @@ class AsyncConnection(BaseConnection):
         self,
         name: str = "",
         binary: bool = False,
-        row_factory: RowFactory = cursor.default_row_factory,
+        row_factory: Optional[RowFactory] = None,
     ) -> "AsyncCursor":
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
@@ -601,7 +603,9 @@ class AsyncConnection(BaseConnection):
             raise NotImplementedError
 
         format = Format.BINARY if binary else Format.TEXT
-        return self.cursor_factory(self, row_factory, format=format)
+        return self.cursor_factory(
+            self, format=format, row_factory=row_factory
+        )
 
     async def execute(
         self,
index f7bdd7c20968a07e8bcc48b3f626d940eb99049c..171318936b4a04c99e835992084c19263b490bab 100644 (file)
@@ -7,7 +7,7 @@ psycopg3 cursor objects
 import sys
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
-from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING
+from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, cast
 from contextlib import contextmanager
 
 from . import pq
@@ -44,10 +44,6 @@ else:
     execute = generators.execute
 
 
-def default_row_factory(cursor: Any) -> RowMaker:
-    return lambda values: values
-
-
 class BaseCursor(Generic[ConnectionType]):
     # Slots with __weakref__ and generic bases don't work on Py 3.6
     # https://bugs.python.org/issue41451
@@ -65,8 +61,8 @@ class BaseCursor(Generic[ConnectionType]):
     def __init__(
         self,
         connection: ConnectionType,
-        row_factory: RowFactory,
         format: Format = Format.TEXT,
+        row_factory: Optional[RowFactory] = None,
     ):
         self._conn = connection
         self.format = format
@@ -269,7 +265,8 @@ class BaseCursor(Generic[ConnectionType]):
             return None
 
         elif res.status == ExecStatus.SINGLE_TUPLE:
-            self._make_row = self._row_factory(self)
+            if self._row_factory:
+                self._make_row = self._row_factory(self)
             self.pgresult = res  # will set it on the transformer too
             # TODO: the transformer may do excessive work here: create a
             # path that doesn't clear the loaders every time.
@@ -373,7 +370,8 @@ class BaseCursor(Generic[ConnectionType]):
 
         self._results = list(results)
         self.pgresult = results[0]
-        self._make_row = self._row_factory(self)
+        if self._row_factory:
+            self._make_row = self._row_factory(self)
         nrows = self.pgresult.command_tuples
         if nrows is not None:
             if self._rowcount < 0:
@@ -497,8 +495,7 @@ class Cursor(BaseCursor["Connection"]):
             while self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                assert self._make_row is not None
-                yield self._make_row(rec)
+                yield self._make_row(rec) if self._make_row else cast(Row, rec)
 
     def fetchone(self) -> Optional[Row]:
         """
@@ -510,8 +507,9 @@ class Cursor(BaseCursor["Connection"]):
         record = self._tx.load_row(self._pos)
         if record is not None:
             self._pos += 1
-            assert self._make_row is not None
-            return self._make_row(record)
+            return (
+                self._make_row(record) if self._make_row else cast(Row, record)
+            )
         return record
 
     def fetchmany(self, size: int = 0) -> Sequence[Row]:
@@ -529,8 +527,9 @@ class Cursor(BaseCursor["Connection"]):
             self._pos, min(self._pos + size, self.pgresult.ntuples)
         )
         self._pos += len(records)
-        assert self._make_row is not None
-        return [self._make_row(r) for r in records]
+        if self._make_row:
+            return list(map(self._make_row, records))
+        return cast(Sequence[Row], records)
 
     def fetchall(self) -> Sequence[Row]:
         """
@@ -540,8 +539,9 @@ class Cursor(BaseCursor["Connection"]):
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
         self._pos += self.pgresult.ntuples
-        assert self._make_row is not None
-        return [self._make_row(r) for r in records]
+        if self._make_row:
+            return list(map(self._make_row, records))
+        return cast(Sequence[Row], records)
 
     def __iter__(self) -> Iterator[Row]:
         self._check_result()
@@ -553,8 +553,7 @@ class Cursor(BaseCursor["Connection"]):
             if row is None:
                 break
             self._pos += 1
-            assert self._make_row is not None
-            yield self._make_row(row)
+            yield self._make_row(row) if self._make_row else cast(Row, row)
 
     @contextmanager
     def copy(self, statement: Query) -> Iterator[Copy]:
@@ -613,16 +612,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             while await self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                assert self._make_row is not None
-                yield self._make_row(rec)
+                yield self._make_row(rec) if self._make_row else cast(Row, rec)
 
     async def fetchone(self) -> Optional[Row]:
         self._check_result()
         rv = self._tx.load_row(self._pos)
         if rv is not None:
             self._pos += 1
-            assert self._make_row is not None
-            return self._make_row(rv)
+            return self._make_row(rv) if self._make_row else cast(Row, rv)
         return rv
 
     async def fetchmany(self, size: int = 0) -> List[Row]:
@@ -635,16 +632,18 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             self._pos, min(self._pos + size, self.pgresult.ntuples)
         )
         self._pos += len(records)
-        assert self._make_row is not None
-        return [self._make_row(r) for r in records]
+        if self._make_row:
+            return list(map(self._make_row, records))
+        return cast(List[Row], records)
 
     async def fetchall(self) -> List[Row]:
         self._check_result()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
         self._pos += self.pgresult.ntuples
-        assert self._make_row is not None
-        return [self._make_row(r) for r in records]
+        if self._make_row:
+            return list(map(self._make_row, records))
+        return cast(List[Row], records)
 
     async def __aiter__(self) -> AsyncIterator[Row]:
         self._check_result()
@@ -656,8 +655,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             if row is None:
                 break
             self._pos += 1
-            assert self._make_row is not None
-            yield self._make_row(row)
+            yield self._make_row(row) if self._make_row else cast(Row, row)
 
     @asynccontextmanager
     async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: