# Copyright (C) 2020 The Psycopg Team
-from typing import Any, AsyncIterator, Generic, List, Iterable, Iterator
+from typing import Any, AsyncIterator, List, Iterable, Iterator
from typing import Optional, TypeVar, TYPE_CHECKING, overload
from warnings import warn
DEFAULT_ITERSIZE = 100
-class ServerCursorHelper(Generic[ConnectionType, Row]):
- __slots__ = ("name", "scrollable", "withhold", "described", "_format")
- """Helper object for common ServerCursor code.
+class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
+ """Mixin to add ServerCursor behaviour and implementation a BaseCursor."""
- TODO: this should be a mixin, but couldn't find a way to work it
- correctly with the generic.
- """
+ __slots__ = "_name _scrollable _withhold _described itersize _format".split()
def __init__(
self,
scrollable: Optional[bool],
withhold: bool,
):
- self.name = name
- self.scrollable = scrollable
- self.withhold = withhold
- self.described = False
+ self._name = name
+ self._scrollable = scrollable
+ self._withhold = withhold
+ self._described = False
+ self.itersize: int = DEFAULT_ITERSIZE
self._format = pq.Format.TEXT
- def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str:
+ def __repr__(self) -> str:
# Insert the name as the second word
- parts = parts = BaseCursor.__repr__(cur).split(None, 1)
- parts.insert(1, f"{self.name!r}")
+ parts = super().__repr__().split(None, 1)
+ parts.insert(1, f"{self._name!r}")
return " ".join(parts)
+ @property
+ def name(self) -> str:
+ """The name of the cursor."""
+ return self._name
+
+ @property
+ def scrollable(self) -> Optional[bool]:
+ """
+ Whether the cursor is scrollable or not.
+
+ If `!None` leave the choice to the server. Use `!True` if you want to
+ use `scroll()` on the cursor.
+ """
+ return self._scrollable
+
+ @property
+ def withhold(self) -> bool:
+ """
+ If the cursor can be used after the creating transaction has committed.
+ """
+ return self._withhold
+
def _declare_gen(
self,
- cur: BaseCursor[ConnectionType, Row],
query: Query,
params: Optional[Params] = None,
binary: Optional[bool] = None,
) -> PQGen[None]:
"""Generator implementing `ServerCursor.execute()`."""
- query = self._make_declare_statement(cur, query)
+ query = self._make_declare_statement(query)
# If the cursor is being reused, the previous one must be closed.
- if self.described:
- yield from self._close_gen(cur)
- self.described = False
-
- yield from cur._start_query(query)
- pgq = cur._convert_query(query, params)
- cur._execute_send(pgq, no_pqexec=True)
- results = yield from execute(cur._conn.pgconn)
+ if self._described:
+ yield from self._close_gen()
+ self._described = False
+
+ yield from self._start_query(query)
+ pgq = self._convert_query(query, params)
+ self._execute_send(pgq, no_pqexec=True)
+ results = yield from execute(self._conn.pgconn)
if results[-1].status != pq.ExecStatus.COMMAND_OK:
- cur._raise_for_result(results[-1])
+ self._raise_for_result(results[-1])
# Set the format, which will be used by describe and fetch operations
if binary is None:
- self._format = cur.format
+ self._format = self.format
else:
self._format = pq.Format.BINARY if binary else pq.Format.TEXT
# The above result only returned COMMAND_OK. Get the cursor shape
- yield from self._describe_gen(cur)
+ yield from self._describe_gen()
- def _describe_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
- conn = cur._conn
- conn.pgconn.send_describe_portal(self.name.encode(cur._encoding))
- results = yield from execute(conn.pgconn)
- cur._check_results(results)
- cur._results = results
- cur._set_current_result(0, format=self._format)
- self.described = True
+ def _describe_gen(self) -> PQGen[None]:
+ self._pgconn.send_describe_portal(self._name.encode(self._encoding))
+ results = yield from execute(self._pgconn)
+ self._check_results(results)
+ self._results = results
+ self._set_current_result(0, format=self._format)
+ self._described = True
- def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
- ts = cur._conn.pgconn.transaction_status
+ def _close_gen(self) -> PQGen[None]:
+ ts = self._conn.pgconn.transaction_status
# if the connection is not in a sane state, don't even try
if ts not in (pq.TransactionStatus.IDLE, pq.TransactionStatus.INTRANS):
return
# If we are IDLE, a WITHOUT HOLD cursor will surely have gone already.
- if not self.withhold and ts == pq.TransactionStatus.IDLE:
+ if not self._withhold and ts == pq.TransactionStatus.IDLE:
return
# if we didn't declare the cursor ourselves we still have to close it
# but we must make sure it exists.
- if not self.described:
+ if not self._described:
query = sql.SQL(
"SELECT 1 FROM pg_catalog.pg_cursors WHERE name = {}"
- ).format(sql.Literal(self.name))
- res = yield from cur._conn._exec_command(query)
+ ).format(sql.Literal(self._name))
+ res = yield from self._conn._exec_command(query)
# pipeline mode otherwise, unsupported here.
assert res is not None
if res.ntuples == 0:
return
- query = sql.SQL("CLOSE {}").format(sql.Identifier(self.name))
- yield from cur._conn._exec_command(query)
+ query = sql.SQL("CLOSE {}").format(sql.Identifier(self._name))
+ yield from self._conn._exec_command(query)
- def _fetch_gen(
- self, cur: BaseCursor[ConnectionType, Row], num: Optional[int]
- ) -> PQGen[List[Row]]:
- if cur.closed:
+ def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Row]]:
+ if self.closed:
raise e.InterfaceError("the cursor is closed")
# If we are stealing the cursor, make sure we know its shape
- if not self.described:
- yield from cur._start_query()
- yield from self._describe_gen(cur)
+ if not self._described:
+ yield from self._start_query()
+ yield from self._describe_gen()
query = sql.SQL("FETCH FORWARD {} FROM {}").format(
sql.SQL("ALL") if num is None else sql.Literal(num),
- sql.Identifier(self.name),
+ sql.Identifier(self._name),
)
- res = yield from cur._conn._exec_command(query, result_format=self._format)
+ res = yield from self._conn._exec_command(query, result_format=self._format)
# pipeline mode otherwise, unsupported here.
assert res is not None
- cur.pgresult = res
- cur._tx.set_pgresult(res, set_loaders=False)
- return cur._tx.load_rows(0, res.ntuples, cur._make_row)
+ self.pgresult = res
+ self._tx.set_pgresult(res, set_loaders=False)
+ return self._tx.load_rows(0, res.ntuples, self._make_row)
- def _scroll_gen(
- self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str
- ) -> PQGen[None]:
+ def _scroll_gen(self, value: int, mode: str) -> PQGen[None]:
if mode not in ("relative", "absolute"):
raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
query = sql.SQL("MOVE{} {} FROM {}").format(
sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
sql.Literal(value),
- sql.Identifier(self.name),
+ sql.Identifier(self._name),
)
- yield from cur._conn._exec_command(query)
+ yield from self._conn._exec_command(query)
- def _make_declare_statement(
- self, cur: BaseCursor[ConnectionType, Row], query: Query
- ) -> sql.Composable:
+ def _make_declare_statement(self, query: Query) -> sql.Composable:
if isinstance(query, bytes):
- query = query.decode(cur._encoding)
+ query = query.decode(self._encoding)
if not isinstance(query, sql.Composable):
query = sql.SQL(query)
parts = [
sql.SQL("DECLARE"),
- sql.Identifier(self.name),
+ sql.Identifier(self._name),
]
- if self.scrollable is not None:
- parts.append(sql.SQL("SCROLL" if self.scrollable else "NO SCROLL"))
+ if self._scrollable is not None:
+ parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
parts.append(sql.SQL("CURSOR"))
- if self.withhold:
+ if self._withhold:
parts.append(sql.SQL("WITH HOLD"))
parts.append(sql.SQL("FOR"))
parts.append(query)
_AC = TypeVar("_AC", bound="AsyncServerCursor[Any]")
-class ServerCursor(Cursor[Row]):
+class ServerCursor(ServerCursorMixin["Connection[Row]", Row], Cursor[Row]):
__module__ = "psycopg"
- __slots__ = ("_helper", "itersize")
+ __slots__ = ()
@overload
def __init__(
scrollable: Optional[bool] = None,
withhold: bool = False,
):
- super().__init__(connection, row_factory=row_factory or connection.row_factory)
- self._helper: ServerCursorHelper["Connection[Any]", Row]
- self._helper = ServerCursorHelper(name, scrollable, withhold)
- self.itersize: int = DEFAULT_ITERSIZE
+ Cursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
def __del__(self) -> None:
if not self.closed:
ResourceWarning,
)
- def __repr__(self) -> str:
- return self._helper._repr(self)
-
- @property
- def name(self) -> str:
- """The name of the cursor."""
- return self._helper.name
-
- @property
- def scrollable(self) -> Optional[bool]:
- """
- Whether the cursor is scrollable or not.
-
- If `!None` leave the choice to the server. Use `!True` if you want to
- use `scroll()` on the cursor.
- """
- return self._helper.scrollable
-
- @property
- def withhold(self) -> bool:
- """
- If the cursor can be used after the creating transaction has committed.
- """
- return self._helper.withhold
-
def close(self) -> None:
"""
Close the current cursor and free associated resources.
if self.closed:
return
if not self._conn.closed:
- self._conn.wait(self._helper._close_gen(self))
+ self._conn.wait(self._close_gen())
super().close()
def execute(
try:
with self._conn.lock:
- self._conn.wait(self._helper._declare_gen(self, query, params, binary))
+ self._conn.wait(self._declare_gen(query, params, binary))
except e.Error as ex:
raise ex.with_traceback(None)
def fetchone(self) -> Optional[Row]:
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(self, 1))
+ recs = self._conn.wait(self._fetch_gen(1))
if recs:
self._pos += 1
return recs[0]
if not size:
size = self.arraysize
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(self, size))
+ recs = self._conn.wait(self._fetch_gen(size))
self._pos += len(recs)
return recs
def fetchall(self) -> List[Row]:
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(self, None))
+ recs = self._conn.wait(self._fetch_gen(None))
self._pos += len(recs)
return recs
def __iter__(self) -> Iterator[Row]:
while True:
with self._conn.lock:
- recs = self._conn.wait(self._helper._fetch_gen(self, self.itersize))
+ recs = self._conn.wait(self._fetch_gen(self.itersize))
for rec in recs:
self._pos += 1
yield rec
def scroll(self, value: int, mode: str = "relative") -> None:
with self._conn.lock:
- self._conn.wait(self._helper._scroll_gen(self, value, mode))
+ self._conn.wait(self._scroll_gen(value, mode))
# Postgres doesn't have a reliable way to report a cursor out of bound
if mode == "relative":
self._pos += value
self._pos = value
-class AsyncServerCursor(AsyncCursor[Row]):
+class AsyncServerCursor(
+ ServerCursorMixin["AsyncConnection[Row]", Row], AsyncCursor[Row]
+):
__module__ = "psycopg"
- __slots__ = ("_helper", "itersize")
+ __slots__ = ()
@overload
def __init__(
scrollable: Optional[bool] = None,
withhold: bool = False,
):
- super().__init__(connection, row_factory=row_factory or connection.row_factory)
- self._helper: ServerCursorHelper["AsyncConnection[Any]", Row]
- self._helper = ServerCursorHelper(name, scrollable, withhold)
- self.itersize: int = DEFAULT_ITERSIZE
+ AsyncCursor.__init__(
+ self, connection, row_factory=row_factory or connection.row_factory
+ )
+ ServerCursorMixin.__init__(self, name, scrollable, withhold)
def __del__(self) -> None:
if not self.closed:
ResourceWarning,
)
- def __repr__(self) -> str:
- return self._helper._repr(self)
-
- @property
- def name(self) -> str:
- return self._helper.name
-
- @property
- def scrollable(self) -> Optional[bool]:
- return self._helper.scrollable
-
- @property
- def withhold(self) -> bool:
- return self._helper.withhold
-
async def close(self) -> None:
async with self._conn.lock:
if self.closed:
return
if not self._conn.closed:
- await self._conn.wait(self._helper._close_gen(self))
+ await self._conn.wait(self._close_gen())
await super().close()
async def execute(
try:
async with self._conn.lock:
- await self._conn.wait(
- self._helper._declare_gen(self, query, params, binary)
- )
+ await self._conn.wait(self._declare_gen(query, params, binary))
except e.Error as ex:
raise ex.with_traceback(None)
async def fetchone(self) -> Optional[Row]:
async with self._conn.lock:
- recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
+ recs = await self._conn.wait(self._fetch_gen(1))
if recs:
self._pos += 1
return recs[0]
if not size:
size = self.arraysize
async with self._conn.lock:
- recs = await self._conn.wait(self._helper._fetch_gen(self, size))
+ recs = await self._conn.wait(self._fetch_gen(size))
self._pos += len(recs)
return recs
async def fetchall(self) -> List[Row]:
async with self._conn.lock:
- recs = await self._conn.wait(self._helper._fetch_gen(self, None))
+ recs = await self._conn.wait(self._fetch_gen(None))
self._pos += len(recs)
return recs
async def __aiter__(self) -> AsyncIterator[Row]:
while True:
async with self._conn.lock:
- recs = await self._conn.wait(
- self._helper._fetch_gen(self, self.itersize)
- )
+ recs = await self._conn.wait(self._fetch_gen(self.itersize))
for rec in recs:
self._pos += 1
yield rec
async def scroll(self, value: int, mode: str = "relative") -> None:
async with self._conn.lock:
- await self._conn.wait(self._helper._scroll_gen(self, value, mode))
+ await self._conn.wait(self._scroll_gen(value, mode))