# Copyright (C) 2020-2021 The Psycopg Team
-import weakref
import warnings
from types import TracebackType
from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
class NamedCursorHelper(Generic[ConnectionType]):
- __slots__ = ("name", "_wcur")
+ __slots__ = ("name", "described")
"""Helper object for common NamedCursor code.
TODO: this should be a mixin, but couldn't find a way to work it
correctly with the generic.
"""
- def __init__(
- self,
- name: str,
- cursor: BaseCursor[ConnectionType],
- ):
+ def __init__(self, name: str):
self.name = name
- self._wcur = weakref.ref(cursor)
-
- @property
- def _cur(self) -> BaseCursor[Any]:
- cur = self._wcur()
- assert cur
- return cur
+ self.described = False
def _declare_gen(
- self, query: Query, params: Optional[Params] = None
+ self,
+ cur: BaseCursor[ConnectionType],
+ query: Query,
+ params: Optional[Params] = None,
) -> PQGen[None]:
"""Generator implementing `NamedCursor.execute()`."""
- cur = self._cur
conn = cur._conn
yield from cur._start_query(query)
pgq = cur._convert_query(query, params)
cur._execute_results(results)
# The above result is an COMMAND_OK. Get the cursor result shape
+ yield from self._describe_gen(cur)
+
+ def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+ conn = cur._conn
conn.pgconn.send_describe_portal(
self.name.encode(conn.client_encoding)
)
results = yield from execute(conn.pgconn)
cur._execute_results(results)
+ self.described = True
- def _close_gen(self) -> PQGen[None]:
- cur = self._cur
+ def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
query = sql.SQL("close {}").format(sql.Identifier(self.name))
yield from cur._conn._exec_command(query)
- def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Tuple[Any, ...]]]:
+ def _fetch_gen(
+ self, cur: BaseCursor[ConnectionType], num: Optional[int]
+ ) -> PQGen[List[Tuple[Any, ...]]]:
+ # If we are stealing the cursor, make sure we know its shape
+ if not self.described:
+ yield from cur._start_query()
+ yield from self._describe_gen(cur)
+
if num is not None:
howmuch: sql.Composable = sql.Literal(num)
else:
howmuch = sql.SQL("all")
- cur = self._cur
query = sql.SQL("fetch forward {} from {}").format(
howmuch, sql.Identifier(self.name)
)
cur.pgresult = res
return cur._tx.load_rows(0, res.ntuples)
- def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
+ def _scroll_gen(
+ self, cur: BaseCursor[ConnectionType], value: int, mode: str
+ ) -> PQGen[None]:
if mode not in ("relative", "absolute"):
raise ValueError(
f"bad mode: {mode}. It should be 'relative' or 'absolute'"
sql.Literal(value),
sql.Identifier(self.name),
)
- cur = self._cur
yield from cur._conn._exec_command(query)
def _make_declare_statement(
- self, query: Query, scrollable: bool, hold: bool
+ self,
+ cur: BaseCursor[ConnectionType],
+ query: Query,
+ scrollable: bool,
+ hold: bool,
) -> sql.Composable:
- cur = self._cur
if isinstance(query, bytes):
query = query.decode(cur._conn.client_encoding)
if not isinstance(query, sql.Composable):
format: Format = Format.TEXT,
):
super().__init__(connection, format=format)
- self._helper = NamedCursorHelper(name, self)
+ self._helper: NamedCursorHelper["Connection"] = NamedCursorHelper(name)
self.itersize = DEFAULT_ITERSIZE
def __del__(self) -> None:
Close the current cursor and free associated resources.
"""
with self._conn.lock:
- self._conn.wait(self._helper._close_gen())
+ self._conn.wait(self._helper._close_gen(self))
self._close()
def execute(
Execute a query or command to the database.
"""
query = self._helper._make_declare_statement(
- query, scrollable=scrollable, hold=hold
+ self, query, scrollable=scrollable, hold=hold
)
with self._conn.lock:
- self._conn.wait(self._helper._declare_gen(query, params))
+ self._conn.wait(self._helper._declare_gen(self, query, params))
return self
def fetchone(self) -> Optional[Sequence[Any]]:
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(1))
+ recs = self._conn.wait(self._helper._fetch_gen(self, 1))
if recs:
self._pos += 1
return recs[0]
if not size:
size = self.arraysize
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(size))
+ recs = self._conn.wait(self._helper._fetch_gen(self, size))
self._pos += len(recs)
return recs
def fetchall(self) -> Sequence[Sequence[Any]]:
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(None))
+ recs = self._conn.wait(self._helper._fetch_gen(self, None))
self._pos += len(recs)
return recs
def __iter__(self) -> Iterator[Sequence[Any]]:
while True:
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(self.itersize))
+ recs = self._conn.wait(
+ self._helper._fetch_gen(self, self.itersize)
+ )
for rec in recs:
self._pos += 1
yield rec
def scroll(self, value: int, mode: str = "relative") -> None:
with self._conn.lock:
- self._conn.wait(self._helper._scroll_gen(value, mode))
+ self._conn.wait(self._helper._scroll_gen(self, value, mode))
# Postgres doesn't have a reliable way to report a cursor out of bound
if mode == "relative":
self._pos += value
format: Format = Format.TEXT,
):
super().__init__(connection, format=format)
- self._helper = NamedCursorHelper(name, self)
+ self._helper: NamedCursorHelper["AsyncConnection"]
+ self._helper = NamedCursorHelper(name)
self.itersize = DEFAULT_ITERSIZE
def __del__(self) -> None:
Close the current cursor and free associated resources.
"""
async with self._conn.lock:
- await self._conn.wait(self._helper._close_gen())
+ await self._conn.wait(self._helper._close_gen(self))
self._close()
async def execute(
Execute a query or command to the database.
"""
query = self._helper._make_declare_statement(
- query, scrollable=scrollable, hold=hold
+ self, query, scrollable=scrollable, hold=hold
)
async with self._conn.lock:
- await self._conn.wait(self._helper._declare_gen(query, params))
+ await self._conn.wait(
+ self._helper._declare_gen(self, query, params)
+ )
return self
async def fetchone(self) -> Optional[Sequence[Any]]:
async with self._conn.lock:
- recs = await self._conn.wait(self._helper._fetch_gen(1))
+ recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
if recs:
self._pos += 1
return recs[0]
if not size:
size = self.arraysize
async with self._conn.lock:
- recs = await self._conn.wait(self._helper._fetch_gen(size))
+ recs = await self._conn.wait(self._helper._fetch_gen(self, size))
self._pos += len(recs)
return recs
async def fetchall(self) -> Sequence[Sequence[Any]]:
async with self._conn.lock:
- recs = await self._conn.wait(self._helper._fetch_gen(None))
+ recs = await self._conn.wait(self._helper._fetch_gen(self, None))
self._pos += len(recs)
return recs
while True:
async with self._conn.lock:
recs = await self._conn.wait(
- self._helper._fetch_gen(self.itersize)
+ self._helper._fetch_gen(self, self.itersize)
)
for rec in recs:
self._pos += 1
async def scroll(self, value: int, mode: str = "relative") -> None:
async with self._conn.lock:
- await self._conn.wait(self._helper._scroll_gen(value, mode))
+ await self._conn.wait(self._helper._scroll_gen(self, value, mode))