from . import encodings
from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
from .sql import Composable
-from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
-from .proto import ConnectionType
+from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params
+from .proto import AdaptContext, ConnectionType
from .conninfo import make_conninfo
from .generators import notifies
from .transaction import Transaction, AsyncTransaction
"""Close the database connection."""
self.pgconn.finish()
- def cursor(self, name: str = "", binary: bool = False) -> "Cursor":
+ def cursor(
+ self,
+ name: str = "",
+ binary: bool = False,
+ row_factory: RowFactory = cursor.default_row_factory,
+ ) -> "Cursor":
"""
Return a new `Cursor` to send commands and queries to the connection.
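+
+ The *row_factory* argument allows customising how fetched rows are
+ returned; by default they are plain sequences of values.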
"""
raise NotImplementedError
format = Format.BINARY if binary else Format.TEXT
- return self.cursor_factory(self, format=format)
+ return self.cursor_factory(self, row_factory, format=format)
def execute(
self,
self.pgconn.finish()
async def cursor(
- self, name: str = "", binary: bool = False
+ self,
+ name: str = "",
+ binary: bool = False,
+ row_factory: RowFactory = cursor.default_row_factory,
) -> "AsyncCursor":
"""
Return a new `AsyncCursor` to send commands and queries to the connection.
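+
+ The *row_factory* argument allows customising how fetched rows are
+ returned; by default they are plain sequences of values.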
raise NotImplementedError
format = Format.BINARY if binary else Format.TEXT
- return self.cursor_factory(self, format=format)
+ return self.cursor_factory(self, row_factory, format=format)
async def execute(
self,
from .pq import ExecStatus, Format
from .copy import Copy, AsyncCopy
from .proto import ConnectionType, Query, Params, PQGen
+from .proto import Row, RowFactory, RowMaker
from ._column import Column
from ._queries import PostgresQuery
from ._preparing import Prepare
execute = generators.execute
+def default_row_factory(cursor: Any) -> RowMaker:
+ """Return a `RowMaker` that leaves each row as the loaded sequence of values."""
+ return lambda values: values
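+
+
+# A row factory is called with the cursor once a result is available and
+# returns a `RowMaker` turning each loaded sequence of values into the row
+# object handed back to the caller. As an illustration only (a hypothetical
+# factory, not part of this patch), a dict-producing variant could be sketched
+# as follows, assuming the cursor's `description` exposes the result columns
+# (the `Column` objects imported above):
+#
+#     def dict_row_factory(cursor: Any) -> RowMaker:
+#         names = [c.name for c in (cursor.description or ())]
+#         return lambda values: dict(zip(names, values))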
+
+
class BaseCursor(Generic[ConnectionType]):
# Slots with __weakref__ and generic bases don't work on Py 3.6
# https://bugs.python.org/issue41451
if sys.version_info >= (3, 7):
__slots__ = """
_conn format _adapters arraysize _closed _results _pgresult _pos
- _iresult _rowcount _pgq _tx _last_query
+ _iresult _rowcount _pgq _tx _last_query _row_factory _make_row
__weakref__
""".split()
def __init__(
self,
connection: ConnectionType,
+ row_factory: RowFactory,
format: Format = Format.TEXT,
):
self._conn = connection
self.format = format
self._adapters = adapt.AdaptersMap(connection.adapters)
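+ # factory invoked with the cursor to build a RowMaker for each new result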
+ self._row_factory = row_factory
self.arraysize = 1
self._closed = False
self._last_query: Optional[Query] = None
def _reset(self) -> None:
self._results: List["PGresult"] = []
self._pgresult: Optional["PGresult"] = None
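+ # row maker for the current result, obtained from the row factory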
+ self._make_row: Optional[RowMaker] = None
self._pos = 0
self._iresult = 0
self._rowcount = -1
return None
elif res.status == ExecStatus.SINGLE_TUPLE:
self.pgresult = res # will set it on the transformer too
+ self._make_row = self._row_factory(self)
# TODO: the transformer may do excessive work here: create a
# path that doesn't clear the loaders every time.
self._results = list(results)
self.pgresult = results[0]
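+ # the result is now set on the cursor, so the row factory can inspect it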
+ self._make_row = self._row_factory(self)
nrows = self.pgresult.command_tuples
if nrows is not None:
if self._rowcount < 0:
def stream(
self, query: Query, params: Optional[Params] = None
- ) -> Iterator[Sequence[Any]]:
+ ) -> Iterator[Row]:
"""
Iterate row-by-row on a result from the database.
"""
while self._conn.wait(self._stream_fetchone_gen()):
rec = self._tx.load_row(0)
assert rec is not None
- yield rec
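+ # _make_row was created when the result arrived, so it cannot be None here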
+ assert self._make_row is not None
+ yield self._make_row(rec)
- def fetchone(self) -> Optional[Sequence[Any]]:
+ def fetchone(self) -> Optional[Row]:
"""
Return the next record from the current recordset.
record = self._tx.load_row(self._pos)
if record is not None:
self._pos += 1
+ assert self._make_row is not None
+ return self._make_row(record)
- return record
+ return None
- def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+ def fetchmany(self, size: int = 0) -> List[Row]:
"""
Return the next *size* records from the current recordset.
self._pos, min(self._pos + size, self.pgresult.ntuples)
)
self._pos += len(records)
- return records
+ assert self._make_row is not None
+ return [self._make_row(r) for r in records]
- def fetchall(self) -> Sequence[Sequence[Any]]:
+ def fetchall(self) -> List[Row]:
"""
Return all the remaining records from the current recordset.
"""
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
self._pos += self.pgresult.ntuples
- return records
+ assert self._make_row is not None
+ return [self._make_row(r) for r in records]
- def __iter__(self) -> Iterator[Sequence[Any]]:
+ def __iter__(self) -> Iterator[Row]:
self._check_result()
load = self._tx.load_row
if row is None:
break
self._pos += 1
- yield row
+ assert self._make_row is not None
+ yield self._make_row(row)
@contextmanager
def copy(self, statement: Query) -> Iterator[Copy]:
async def stream(
self, query: Query, params: Optional[Params] = None
- ) -> AsyncIterator[Sequence[Any]]:
+ ) -> AsyncIterator[Row]:
async with self._conn.lock:
await self._conn.wait(self._stream_send_gen(query, params))
while await self._conn.wait(self._stream_fetchone_gen()):
rec = self._tx.load_row(0)
assert rec is not None
- yield rec
+ assert self._make_row is not None
+ yield self._make_row(rec)
- async def fetchone(self) -> Optional[Sequence[Any]]:
+ async def fetchone(self) -> Optional[Row]:
self._check_result()
rv = self._tx.load_row(self._pos)
if rv is not None:
self._pos += 1
+ assert self._make_row is not None
+ return self._make_row(rv)
- return rv
+ return None
- async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+ async def fetchmany(self, size: int = 0) -> List[Row]:
self._check_result()
assert self.pgresult
self._pos, min(self._pos + size, self.pgresult.ntuples)
)
self._pos += len(records)
- return records
+ assert self._make_row is not None
+ return [self._make_row(r) for r in records]
- async def fetchall(self) -> Sequence[Sequence[Any]]:
+ async def fetchall(self) -> List[Row]:
self._check_result()
assert self.pgresult
records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
self._pos += self.pgresult.ntuples
- return records
+ assert self._make_row is not None
+ return [self._make_row(r) for r in records]
- async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
+ async def __aiter__(self) -> AsyncIterator[Row]:
self._check_result()
load = self._tx.load_row
if row is None:
break
self._pos += 1
- yield row
+ assert self._make_row is not None
+ yield self._make_row(row)
@asynccontextmanager
async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: