From: Denis Laxalde Date: Wed, 10 Feb 2021 16:58:42 +0000 (+0100) Subject: Make row factory optional X-Git-Tag: 3.0.dev0~106^2~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4361d465509e8e5d156731a11e644bebb58d1c92;p=thirdparty%2Fpsycopg.git Make row factory optional We change the default value of row_factory argument in connection.cursor() to None and thus use a keyword argument. On cursor side, we only set the '_make_row' attribute if a 'row_factory' got passed and we guard all possible calls to _make_row() by an 'if self._make_row' to avoid a Python call per row. Note that, on the other hand, we now need to cast 'row' values to the 'Row' type in order to satisfy type checking. The default_row_factory() is now useless and thus dropped. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 79842ae4a..0caca8bca 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -452,7 +452,7 @@ class Connection(BaseConnection): self, name: str = "", binary: bool = False, - row_factory: RowFactory = cursor.default_row_factory, + row_factory: Optional[RowFactory] = None, ) -> "Cursor": """ Return a new `Cursor` to send commands and queries to the connection. @@ -461,7 +461,9 @@ class Connection(BaseConnection): raise NotImplementedError format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory(self, row_factory, format=format) + return self.cursor_factory( + self, format=format, row_factory=row_factory + ) def execute( self, @@ -592,7 +594,7 @@ class AsyncConnection(BaseConnection): self, name: str = "", binary: bool = False, - row_factory: RowFactory = cursor.default_row_factory, + row_factory: Optional[RowFactory] = None, ) -> "AsyncCursor": """ Return a new `AsyncCursor` to send commands and queries to the connection. @@ -601,7 +603,9 @@ class AsyncConnection(BaseConnection): raise NotImplementedError format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory(self, row_factory, format=format) + return self.cursor_factory( + self, format=format, row_factory=row_factory + ) async def execute( self, diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index f7bdd7c20..171318936 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -7,7 +7,7 @@ psycopg3 cursor objects import sys from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List -from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING +from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, cast from contextlib import contextmanager from . import pq @@ -44,10 +44,6 @@ else: execute = generators.execute -def default_row_factory(cursor: Any) -> RowMaker: - return lambda values: values - - class BaseCursor(Generic[ConnectionType]): # Slots with __weakref__ and generic bases don't work on Py 3.6 # https://bugs.python.org/issue41451 @@ -65,8 +61,8 @@ class BaseCursor(Generic[ConnectionType]): def __init__( self, connection: ConnectionType, - row_factory: RowFactory, format: Format = Format.TEXT, + row_factory: Optional[RowFactory] = None, ): self._conn = connection self.format = format @@ -269,7 +265,8 @@ class BaseCursor(Generic[ConnectionType]): return None elif res.status == ExecStatus.SINGLE_TUPLE: - self._make_row = self._row_factory(self) + if self._row_factory: + self._make_row = self._row_factory(self) self.pgresult = res # will set it on the transformer too # TODO: the transformer may do excessive work here: create a # path that doesn't clear the loaders every time. @@ -373,7 +370,8 @@ class BaseCursor(Generic[ConnectionType]): self._results = list(results) self.pgresult = results[0] - self._make_row = self._row_factory(self) + if self._row_factory: + self._make_row = self._row_factory(self) nrows = self.pgresult.command_tuples if nrows is not None: if self._rowcount < 0: @@ -497,8 +495,7 @@ class Cursor(BaseCursor["Connection"]): while self._conn.wait(self._stream_fetchone_gen()): rec = self._tx.load_row(0) assert rec is not None - assert self._make_row is not None - yield self._make_row(rec) + yield self._make_row(rec) if self._make_row else cast(Row, rec) def fetchone(self) -> Optional[Row]: """ @@ -510,8 +507,9 @@ class Cursor(BaseCursor["Connection"]): record = self._tx.load_row(self._pos) if record is not None: self._pos += 1 - assert self._make_row is not None - return self._make_row(record) + return ( + self._make_row(record) if self._make_row else cast(Row, record) + ) return record def fetchmany(self, size: int = 0) -> Sequence[Row]: @@ -529,8 +527,9 @@ class Cursor(BaseCursor["Connection"]): self._pos, min(self._pos + size, self.pgresult.ntuples) ) self._pos += len(records) - assert self._make_row is not None - return [self._make_row(r) for r in records] + if self._make_row: + return list(map(self._make_row, records)) + return cast(Sequence[Row], records) def fetchall(self) -> Sequence[Row]: """ @@ -540,8 +539,9 @@ class Cursor(BaseCursor["Connection"]): assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) self._pos += self.pgresult.ntuples - assert self._make_row is not None - return [self._make_row(r) for r in records] + if self._make_row: + return list(map(self._make_row, records)) + return cast(Sequence[Row], records) def __iter__(self) -> Iterator[Row]: self._check_result() @@ -553,8 +553,7 @@ class Cursor(BaseCursor["Connection"]): if row is None: break self._pos += 1 - assert self._make_row is not None - yield self._make_row(row) + yield self._make_row(row) if self._make_row else cast(Row, row) @contextmanager def copy(self, statement: Query) -> Iterator[Copy]: @@ -613,16 +612,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): while await self._conn.wait(self._stream_fetchone_gen()): rec = self._tx.load_row(0) assert rec is not None - assert self._make_row is not None - yield self._make_row(rec) + yield self._make_row(rec) if self._make_row else cast(Row, rec) async def fetchone(self) -> Optional[Row]: self._check_result() rv = self._tx.load_row(self._pos) if rv is not None: self._pos += 1 - assert self._make_row is not None - return self._make_row(rv) + return self._make_row(rv) if self._make_row else cast(Row, rv) return rv async def fetchmany(self, size: int = 0) -> List[Row]: @@ -635,16 +632,18 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._pos, min(self._pos + size, self.pgresult.ntuples) ) self._pos += len(records) - assert self._make_row is not None - return [self._make_row(r) for r in records] + if self._make_row: + return list(map(self._make_row, records)) + return cast(List[Row], records) async def fetchall(self) -> List[Row]: self._check_result() assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) self._pos += self.pgresult.ntuples - assert self._make_row is not None - return [self._make_row(r) for r in records] + if self._make_row: + return list(map(self._make_row, records)) + return cast(List[Row], records) async def __aiter__(self) -> AsyncIterator[Row]: self._check_result() @@ -656,8 +655,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): if row is None: break self._pos += 1 - assert self._make_row is not None - yield self._make_row(row) + yield self._make_row(row) if self._make_row else cast(Row, row) @asynccontextmanager async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: