conn._autocommit = autocommit
return conn
- def _exec_command(self, command: Query) -> PQGen[None]:
+ def _exec_command(self, command: Query) -> PQGen["PGresult"]:
"""
Generator to send a command to the backend and receive the result.
- Only used to implement internal commands such as commit, returning
- no result. The cursor can do more complex stuff.
+ Only used to implement internal commands such as "commit", with any
+ arguments bound client-side. The cursor can do more complex stuff.
"""
if self.pgconn.status != ConnStatus.OK:
if self.pgconn.status == ConnStatus.BAD:
self.pgconn.send_query(command)
result = (yield from execute(self.pgconn))[-1]
- if result.status != ExecStatus.COMMAND_OK:
+ if result.status not in (ExecStatus.COMMAND_OK, ExecStatus.TUPLES_OK):
if result.status == ExecStatus.FATAL_ERROR:
raise e.error_from_result(
result, encoding=self.client_encoding
f"unexpected result {ExecStatus(result.status).name}"
f" from command {command.decode('utf8')!r}"
)
+ return result
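+ # Illustrative sketch (hypothetical helper, not part of this patch): with
+ # the PGresult now returned, an internal generator can run a command and
+ # inspect its outcome in one go, e.g.:
+ #
+ #     def _show_gen(self, setting: bytes) -> PQGen[Optional[bytes]]:
+ #         res = yield from self._exec_command(b"show " + setting)
+ #         return res.get_value(0, 0)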
def _start_query(self) -> PQGen[None]:
"""Generator to start a transaction if necessary."""
import weakref
import warnings
from types import TracebackType
-from typing import Any, Generic, Optional, Type, TYPE_CHECKING
+from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
+from typing import Sequence, Type, Tuple, TYPE_CHECKING
from . import sql
from .pq import Format
from .connection import BaseConnection # noqa: F401
from .connection import Connection, AsyncConnection # noqa: F401
+DEFAULT_ITERSIZE = 100
+
class NamedCursorHelper(Generic[ConnectionType]):
__slots__ = ("name", "_wcur")
) -> PQGen[None]:
"""Generator implementing `NamedCursor.execute()`."""
cur = self._cur
+ conn = cur._conn
yield from cur._start_query(query)
pgq = cur._convert_query(query, params)
cur._execute_send(pgq)
- results = yield from execute(cur._conn.pgconn)
+ results = yield from execute(conn.pgconn)
cur._execute_results(results)
# The above result is a COMMAND_OK. Get the cursor result shape
- cur._conn.pgconn.send_describe_portal(
- self.name.encode(cur._conn.client_encoding)
+ conn.pgconn.send_describe_portal(
+ self.name.encode(conn.client_encoding)
)
- results = yield from execute(cur._conn.pgconn)
+ results = yield from execute(conn.pgconn)
cur._execute_results(results)
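+ # Sketch of the exchange above, assuming a cursor named "foo": the first
+ # round trip runs the 'declare "foo" cursor for ...' statement; the second
+ # describes the portal, which returns the row description (column names
+ # and types, no data rows) used to build the cursor description.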
def _close_gen(self) -> PQGen[None]:
query = sql.SQL("close {}").format(sql.Identifier(self.name))
yield from self._cur._conn._exec_command(query)
+ def _fetch_gen(self, num: Optional[int]) -> PQGen[List[Tuple[Any, ...]]]:
+ if num is not None:
+ howmuch: sql.Composable = sql.Literal(num)
+ else:
+ howmuch = sql.SQL("all")
+
+ cur = self._cur
+ query = sql.SQL("fetch forward {} from {}").format(
+ howmuch, sql.Identifier(self.name)
+ )
+ res = yield from cur._conn._exec_command(query)
+
+ # TODO: loaders don't need to be refreshed
+ cur.pgresult = res
+ nrows = res.ntuples
+ cur._pos += nrows
+ return cur._tx.load_rows(0, nrows)
+
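+ # For example, with a cursor named "foo" (the name used in the tests),
+ # the statements generated above are:
+ #
+ #     _fetch_gen(3)     ->  fetch forward 3 from "foo"
+ #     _fetch_gen(None)  ->  fetch forward all from "foo"
+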
def _make_declare_statement(
self, query: Query, scrollable: bool, hold: bool
) -> sql.Composable:
class NamedCursor(BaseCursor["Connection"]):
__module__ = "psycopg3"
- __slots__ = ("_helper",)
+ __slots__ = ("_helper", "itersize")
def __init__(
self,
):
super().__init__(connection, format=format)
self._helper = NamedCursorHelper(name, self)
+ self.itersize = DEFAULT_ITERSIZE
def __del__(self) -> None:
if not self._closed:
self._conn.wait(self._helper._declare_gen(query, params))
return self
+ def fetchone(self) -> Optional[Sequence[Any]]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(1))
+ return recs[0] if recs else None
+
+ def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+ if not size:
+ size = self.arraysize
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(size))
+ return recs
+
+ def fetchall(self) -> Sequence[Sequence[Any]]:
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(None))
+ return recs
+
+ def __iter__(self) -> Iterator[Sequence[Any]]:
+ while True:
+ with self._conn.lock:
+ recs = self._conn.wait(self._helper._fetch_gen(self.itersize))
+ for rec in recs:
+ yield rec
+ if len(recs) < self.itersize:
+ break
+
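+# Usage sketch for the blocking class (illustrative; "process" stands for
+# any application code):
+#
+#     with conn.cursor("foo") as cur:
+#         cur.execute("select generate_series(1, 1000)")
+#         cur.itersize = 50   # rows per 'fetch forward' round trip
+#         for (i,) in cur:    # batches fetched transparently
+#             process(i)
+
+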
class AsyncNamedCursor(BaseCursor["AsyncConnection"]):
__module__ = "psycopg3"
- __slots__ = ("_helper",)
+ __slots__ = ("_helper", "itersize")
def __init__(
self,
):
super().__init__(connection, format=format)
self._helper = NamedCursorHelper(name, self)
+ self.itersize = DEFAULT_ITERSIZE
def __del__(self) -> None:
if not self._closed:
async with self._conn.lock:
await self._conn.wait(self._helper._declare_gen(query, params))
return self
+
+ async def fetchone(self) -> Optional[Sequence[Any]]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._helper._fetch_gen(1))
+ return recs[0] if recs else None
+
+ async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+ if not size:
+ size = self.arraysize
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._helper._fetch_gen(size))
+ return recs
+
+ async def fetchall(self) -> Sequence[Sequence[Any]]:
+ async with self._conn.lock:
+ recs = await self._conn.wait(self._helper._fetch_gen(None))
+ return recs
+
+ async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
+ while True:
+ async with self._conn.lock:
+ recs = await self._conn.wait(
+ self._helper._fetch_gen(self.itersize)
+ )
+ for rec in recs:
+ yield rec
+ if len(recs) < self.itersize:
+ break
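+
+
+# Async usage sketch, mirroring the blocking example above (illustrative):
+#
+#     async with await aconn.cursor("foo") as cur:
+#         await cur.execute("select generate_series(1, 1000)")
+#         async for rec in cur:
+#             ...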
from . import sql
from .pq import TransactionStatus
from .proto import ConnectionType, PQGen
+from .pq.proto import PGresult
if TYPE_CHECKING:
from .connection import Connection, AsyncConnection # noqa: F401
sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
- def _enter_gen(self) -> PQGen[None]:
+ def _enter_gen(self) -> PQGen[PGresult]:
if self._entered:
raise TypeError("transaction blocks can be used only once")
self._entered = True
else:
return (yield from self._rollback_gen(exc_val))
- def _commit_gen(self) -> PQGen[None]:
+ def _commit_gen(self) -> PQGen[PGresult]:
assert self._conn._savepoints[-1] == self._savepoint_name
self._conn._savepoints.pop()
self._exited = True
conn = Connection.connect(dsn, autocommit=True)
yield conn
conn.close()
+
+
+@pytest.fixture
+def commands(conn, monkeypatch):
+ """The list of commands issued internally by the test connection."""
+ yield patch_exec(conn, monkeypatch)
+
+
+@pytest.fixture
+def acommands(aconn, monkeypatch):
+ """The list of commands issued internally by the test async connection."""
+ yield patch_exec(aconn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+ """Helper to implement the commands fixture both sync and async."""
+ from psycopg3 import sql
+
+ _orig_exec_command = conn._exec_command
+ L = ListPopAll()
+
+ def _exec_command(command):
+ cmdcopy = command
+ if isinstance(cmdcopy, bytes):
+ cmdcopy = cmdcopy.decode(conn.client_encoding)
+ elif isinstance(cmdcopy, sql.Composable):
+ cmdcopy = cmdcopy.as_string(conn)
+
+ L.insert(0, cmdcopy)
+ return _orig_exec_command(command)
+
+ monkeypatch.setattr(conn, "_exec_command", _exec_command)
+ return L
+
+
+class ListPopAll(list):
+ """A list, with a popall() method."""
+
+ def popall(self):
+ out = self[:]
+ del self[:]
+ return out
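+
+
+# For example: patch_exec() records commands most-recent-first (insert at
+# index 0) and popall() drains the list:
+#
+#     l = ListPopAll(["begin", "commit"])
+#     l.popall()   ->  ["begin", "commit"]
+#     l            ->  []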
cur.execute("select generate_series(1, 10) as bar")
del cur
assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+def test_fetchone(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchone() == (2,)
+ assert cur.fetchone() is None
+
+
+def test_fetchmany(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert cur.fetchone() == (4,)
+ assert cur.fetchmany(3) == [(5,)]
+ assert cur.fetchmany(3) == []
+
+
+def test_fetchall(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchall() == [(1,), (2,), (3,)]
+ assert cur.fetchall() == []
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ assert cur.fetchall() == [(2,), (3,)]
+ assert cur.fetchall() == []
+
+
+def test_rownumber(conn):
+ cur = conn.cursor("foo")
+ assert cur.rownumber is None
+
+ cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ cur.fetchone()
+ assert cur.rownumber == 1
+ cur.fetchone()
+ assert cur.rownumber == 2
+ cur.fetchmany(10)
+ assert cur.rownumber == 12
+ cur.fetchall()
+ assert cur.rownumber == 42
+
+
+def test_iter(conn):
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = list(cur)
+ assert recs == [(1,), (2,), (3,)]
+
+ with conn.cursor("foo") as cur:
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert cur.fetchone() == (1,)
+ recs = list(cur)
+ assert recs == [(2,), (3,)]
+
+
+def test_itersize(conn, commands):
+ with conn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ cur.execute("select generate_series(1, %s) as bar", (3,))
+ commands.popall() # flush begin and other noise
+
+ list(cur)
+ cmds = commands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
await cur.execute("select generate_series(1, 10) as bar")
del cur
assert ".close()" in str(recwarn.pop(ResourceWarning).message)
+
+
+async def test_fetchone(aconn):
+ async with await aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (2,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchone() == (2,)
+ assert await cur.fetchone() is None
+
+
+async def test_fetchmany(aconn):
+ async with await aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (5,))
+ assert await cur.fetchmany(3) == [(1,), (2,), (3,)]
+ assert await cur.fetchone() == (4,)
+ assert await cur.fetchmany(3) == [(5,)]
+ assert await cur.fetchmany(3) == []
+
+
+async def test_fetchall(aconn):
+ async with await aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchall() == [(1,), (2,), (3,)]
+ assert await cur.fetchall() == []
+
+ async with await aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ assert await cur.fetchall() == [(2,), (3,)]
+ assert await cur.fetchall() == []
+
+
+async def test_rownumber(aconn):
+ cur = await aconn.cursor("foo")
+ assert cur.rownumber is None
+
+ await cur.execute("select 1 from generate_series(1, 42)")
+ assert cur.rownumber == 0
+
+ await cur.fetchone()
+ assert cur.rownumber == 1
+ await cur.fetchone()
+ assert cur.rownumber == 2
+ await cur.fetchmany(10)
+ assert cur.rownumber == 12
+ await cur.fetchall()
+ assert cur.rownumber == 42
+
+
+async def test_iter(aconn):
+ async with await aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(1,), (2,), (3,)]
+
+ async with await aconn.cursor("foo") as cur:
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ assert await cur.fetchone() == (1,)
+ recs = []
+ async for rec in cur:
+ recs.append(rec)
+ assert recs == [(2,), (3,)]
+
+
+async def test_itersize(aconn, acommands):
+ async with await aconn.cursor("foo") as cur:
+ assert cur.itersize == 100
+ cur.itersize = 2
+ await cur.execute("select generate_series(1, %s) as bar", (3,))
+ acommands.popall() # flush begin and other noise
+
+ async for rec in cur:
+ pass
+ cmds = acommands.popall()
+ assert len(cmds) == 2
+ for cmd in cmds:
+ assert "fetch forward 2" in cmd.lower()
return f()
-class ListPopAll(list):
- """A list, with a popall() method."""
-
- def popall(self):
- out = self[:]
- del self[:]
- return out
-
-
-@pytest.fixture
-def commands(conn, monkeypatch):
- """The list of commands issued internally by the test connection."""
- yield patch_exec(conn, monkeypatch)
-
-
-def patch_exec(conn, monkeypatch):
- """Helper to implement the commands fixture both sync and async."""
- _orig_exec_command = conn._exec_command
- L = ListPopAll()
-
- def _exec_command(command):
- if isinstance(command, bytes):
- command = command.decode(conn.client_encoding)
-
- L.insert(0, command)
- return _orig_exec_command(command)
-
- monkeypatch.setattr(conn, "_exec_command", _exec_command)
- return L
-
-
def in_transaction(conn):
if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE:
return False
from psycopg3 import ProgrammingError, Rollback
from .test_transaction import in_transaction, insert_row, inserted
-from .test_transaction import ExpectedException, patch_exec
+from .test_transaction import ExpectedException
from .test_transaction import create_test_table # noqa # autouse fixture
pytestmark = pytest.mark.asyncio
-@pytest.fixture
-def commands(aconn, monkeypatch):
- """The list of commands issued internally by the test connection."""
- yield patch_exec(aconn, monkeypatch)
-
-
async def test_basic(aconn):
"""Basic use of transaction() to BEGIN and COMMIT a transaction."""
assert not in_transaction(aconn)
pass
-async def test_named_savepoints_successful_exit(aconn, commands):
+async def test_named_savepoints_successful_exit(aconn, acommands):
"""
Entering a transaction context will do one of these things:
1. Begin an outer transaction (if one isn't already in progress)
...and exiting the context successfully will "commit" the same.
"""
+ commands = acommands
+
# Case 1
# Using Transaction explicitly because conn.transaction() enters the context
async with aconn.transaction() as tx:
assert commands.popall() == ["commit"]
-async def test_named_savepoints_exception_exit(aconn, commands):
+async def test_named_savepoints_exception_exit(aconn, acommands):
"""
Same as the previous test but checks that when exiting the context with an
exception, whatever transaction and/or savepoint was started on enter will
be rolled back as appropriate.
"""
+ commands = acommands
+
# Case 1
with pytest.raises(ExpectedException):
async with aconn.transaction() as tx: