from . import pq
from . import errors as e
from .oids import INVALID_OID
-from .proto import LoadFunc, AdaptContext
+from .proto import LoadFunc, AdaptContext, Row, RowMaker
from ._enums import Format
if TYPE_CHECKING:
__module__ = "psycopg3.adapt"
_adapters: "AdaptersMap"
_pgresult: Optional["PGresult"] = None
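+ # called with the loaded values of each record to build the returned
+ # row; when None, rows are returned as plain tuples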
+ make_row: Optional[RowMaker] = None
def __init__(self, context: Optional[AdaptContext] = None):
dumper = cache[key1] = dumper.upgrade(obj, format)
return dumper
- def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
+ def load_rows(self, row0: int, row1: int) -> List[Row]:
res = self._pgresult
if not res:
raise e.InterfaceError("result not set")
f"rows must be included between 0 and {self._ntuples}"
)
- records: List[Tuple[Any, ...]]
+ records: List[Row]
records = [None] * (row1 - row0) # type: ignore[list-item]
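+ # build each row with the configured factory, or as a plain tuple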
+ mkrow = self.make_row or tuple
for row in range(row0, row1):
record: List[Any] = [None] * self._nfields
for col in range(self._nfields):
val = res.get_value(row, col)
if val is not None:
record[col] = self._row_loaders[col](val)
- records[row - row0] = tuple(record)
+ records[row - row0] = mkrow(record)
return records
- def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+ def load_row(self, row: int) -> Optional[Row]:
res = self._pgresult
if not res:
return None
if val is not None:
record[col] = self._row_loaders[col](val)
- return tuple(record)
+ return self.make_row(record) if self.make_row else tuple(record)
def load_sequence(
self, record: Sequence[Optional[bytes]]
from .pq import ExecStatus, Format
from .copy import Copy, AsyncCopy
from .proto import ConnectionType, Query, Params, PQGen
-from .proto import Row, RowFactory, RowMaker
+from .proto import Row, RowFactory
from ._column import Column
from ._queries import PostgresQuery
from ._preparing import Prepare
if sys.version_info >= (3, 7):
__slots__ = """
_conn format _adapters arraysize _closed _results _pgresult _pos
- _iresult _rowcount _pgq _tx _last_query _row_factory _make_row
+ _iresult _rowcount _pgq _tx _last_query _row_factory
__weakref__
""".split()
def _reset(self) -> None:
self._results: List["PGresult"] = []
self._pgresult: Optional["PGresult"] = None
- self._make_row: Optional[RowMaker] = None
self._pos = 0
self._iresult = 0
self._rowcount = -1
elif res.status == ExecStatus.SINGLE_TUPLE:
if self._row_factory:
- self._make_row = self._row_factory(self)
+ self._tx.make_row = self._row_factory(self)
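+ # the transformer will apply make_row itself while loading rows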
self.pgresult = res # will set it on the transformer too
# TODO: the transformer may do excessive work here: create a
# path that doesn't clear the loaders every time.
self._results = list(results)
self.pgresult = results[0]
if self._row_factory:
- self._make_row = self._row_factory(self)
+ self._tx.make_row = self._row_factory(self)
nrows = self.pgresult.command_tuples
if nrows is not None:
if self._rowcount < 0:
while self._conn.wait(self._stream_fetchone_gen()):
rec = self._tx.load_row(0)
assert rec is not None
- yield self._make_row(rec) if self._make_row else rec
+ yield rec
def fetchone(self) -> Optional[Row]:
"""
record = self._tx.load_row(self._pos)
if record is not None:
self._pos += 1
- return self._make_row(record) if self._make_row else record
- return record
+ return record # type: ignore[no-any-return]
def fetchmany(self, size: int = 0) -> Sequence[Row]:
"""
self._pos, min(self._pos + size, self.pgresult.ntuples)
)
self._pos += len(records)
- if self._make_row:
- return list(map(self._make_row, records))
return records
def fetchall(self) -> Sequence[Row]:
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
self._pos += self.pgresult.ntuples
- if self._make_row:
- return list(map(self._make_row, records))
return records
def __iter__(self) -> Iterator[Row]:
if row is None:
break
self._pos += 1
- yield self._make_row(row) if self._make_row else row
+ yield row
@contextmanager
def copy(self, statement: Query) -> Iterator[Copy]:
while await self._conn.wait(self._stream_fetchone_gen()):
rec = self._tx.load_row(0)
assert rec is not None
- yield self._make_row(rec) if self._make_row else rec
+ yield rec
async def fetchone(self) -> Optional[Row]:
self._check_result()
rv = self._tx.load_row(self._pos)
if rv is not None:
self._pos += 1
- return self._make_row(rv) if self._make_row else rv
- return rv
+ return rv # type: ignore[no-any-return]
async def fetchmany(self, size: int = 0) -> List[Row]:
self._check_result()
self._pos, min(self._pos + size, self.pgresult.ntuples)
)
self._pos += len(records)
- if self._make_row:
- return list(map(self._make_row, records))
return records
async def fetchall(self) -> List[Row]:
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
self._pos += self.pgresult.ntuples
- if self._make_row:
- return list(map(self._make_row, records))
return records
async def __aiter__(self) -> AsyncIterator[Row]:
if row is None:
break
self._pos += 1
- yield self._make_row(row) if self._make_row else row
+ yield row
@asynccontextmanager
async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
"""
+# Row factories
+
+Row = TypeVar("Row", Tuple[Any, ...], Any)
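+# Row is the type of the returned rows: a plain tuple by default, or
+# whatever object the configured RowMaker returns.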
+
+
+class RowMaker(Protocol):
+ def __call__(self, __values: Sequence[Any]) -> Row:
+ ...
+
+
+class RowFactory(Protocol):
+ def __call__(self, __cursor: "BaseCursor[ConnectionType]") -> RowMaker:
+ ...
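+
+# A minimal sketch of a conforming RowFactory (illustrative only, not part
+# of this change; `dict_row` is a hypothetical helper and assumes the
+# cursor description is populated):
+#
+#     def dict_row(cursor: "BaseCursor[Any]") -> RowMaker:
+#         titles = [c.name for c in cursor.description or ()]
+#
+#         def make_row(values: Sequence[Any]) -> Dict[str, Any]:
+#             return dict(zip(titles, values))
+#
+#         return make_row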
+
+
# Adaptation types
DumpFunc = Callable[[Any], bytes]
class Transformer(Protocol):
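+ # set by the cursor when a row factory is configured; used by
+ # load_row() and load_rows() to build the returned rows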
+ make_row: Optional[RowMaker] = None
+
def __init__(self, context: Optional[AdaptContext] = None):
...
def get_dumper(self, obj: Any, format: Format) -> "Dumper":
...
- def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]:
+ def load_rows(self, row0: int, row1: int) -> List[Row]:
...
- def load_row(self, row: int) -> Optional[Tuple[Any, ...]]:
+ def load_row(self, row: int) -> Optional[Row]:
...
def load_sequence(
def get_loader(self, oid: int, format: pq.Format) -> "Loader":
...
-
-
-# Row factories
-
-Row = TypeVar("Row", Tuple[Any, ...], Any)
-
-
-class RowMaker(Protocol):
- def __call__(self, __values: Sequence[Any]) -> Row:
- ...
-
-
-class RowFactory(Protocol):
- def __call__(self, __cursor: "BaseCursor[ConnectionType]") -> RowMaker:
- ...
@property
def adapters(self) -> AdaptersMap: ...
@property
+ def make_row(self) -> Optional[proto.RowMaker]: ...
+ @make_row.setter
+ def make_row(self, row_maker: proto.RowMaker) -> None: ...
+ @property
def pgresult(self) -> Optional[PGresult]: ...
@pgresult.setter
def pgresult(self, result: Optional[PGresult]) -> None: ...
self, params: Sequence[Any], formats: Sequence[Format]
) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
- def load_rows(self, row0: int, row1: int) -> List[Tuple[Any, ...]]: ...
- def load_row(self, row: int) -> Optional[Tuple[Any, ...]]: ...
+ def load_rows(self, row0: int, row1: int) -> List[proto.Row]: ...
+ def load_row(self, row: int) -> Optional[proto.Row]: ...
def load_sequence(
self, record: Sequence[Optional[bytes]]
) -> Tuple[Any, ...]: ...
from psycopg3 import errors as e
from psycopg3._enums import Format as Pg3Format
from psycopg3.pq import Format as PqFormat
+from psycopg3.proto import Row, RowMaker
# internal structure: you are not supposed to know this. But accessing it
# directly saves some 10% on the innermost loop, so I'm willing to ask for
# forgiveness later...
cdef int _nfields, _ntuples
cdef list _row_dumpers
cdef list _row_loaders
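+ # the RowMaker currently set, or None to return plain tuples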
+ cdef object _make_row
def __cinit__(self, context: Optional["AdaptContext"] = None):
if context is not None:
self.adapters = global_adapters
self.connection = None
+ @property
+ def make_row(self) -> Optional[RowMaker]:
+ return self._make_row
+
+ @make_row.setter
+ def make_row(self, row_maker: RowMaker) -> None:
+ self._make_row = row_maker
+
@property
def pgresult(self) -> Optional[PGresult]:
return self._pgresult
return ps, ts, fs
- def load_rows(self, int row0, int row1) -> List[Tuple[Any, ...]]:
+ def load_rows(self, int row0, int row1) -> List[Row]:
if self._pgresult is None:
raise e.InterfaceError("result not set")
Py_INCREF(pyval)
PyTuple_SET_ITEM(<object>brecord, col, pyval)
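+ # apply the row factory over the loaded tuples, if one is configured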
+ if self.make_row:
+ return list(map(self.make_row, records))
return records
- def load_row(self, int row) -> Optional[Tuple[Any, ...]]:
+ def load_row(self, int row) -> Optional[Row]:
if self._pgresult is None:
return None
Py_INCREF(pyval)
PyTuple_SET_ITEM(record, col, pyval)
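+ # build the single returned row with the factory too, so fetchone()
+ # yields the same row type as fetchmany()/fetchall()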
+ if self.make_row:
+ return self.make_row(record)
return record
cpdef object load_sequence(self, record: Sequence[Optional[bytes]]):