From 151465ff301d985fb1c40feabd8e9d15e31c800c Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 22 Dec 2020 18:09:34 +0100 Subject: [PATCH] Implement connection/cursor procedures as generators Generators are common to both sync and async code, so there is much less code duplication, which will be useful for more complex features such as prepared transactions. As a side effect of the refactoring, operations on closed connections now raise OperationalError rather than InterfaceError, because connection closure may depend on external factors. Dropped the fd from the PQGen stream: having the fd in the middle of the `Ready` values made it more difficult to chain more than one operation in the same generator. The fd is now passed as an extra parameter. --- psycopg3/psycopg3/connection.py | 311 +++++++++++++-------------- psycopg3/psycopg3/cursor.py | 175 +++++++-------- psycopg3/psycopg3/generators.py | 5 - psycopg3/psycopg3/proto.py | 2 +- psycopg3/psycopg3/transaction.py | 100 ++++----- psycopg3/psycopg3/waiting.py | 208 ++++++++++-------- psycopg3_c/psycopg3_c/generators.pyx | 4 - tests/pq/test_async.py | 24 ++- tests/pq/test_pgconn.py | 2 +- tests/test_connection.py | 2 +- tests/test_connection_async.py | 2 +- tests/test_transaction.py | 11 +- tests/test_transaction_async.py | 19 +- 13 files changed, 416 insertions(+), 449 deletions(-) diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index d801a2238..9af0bc691 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -10,7 +10,7 @@ import logging import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple -from typing import Optional, Type, TYPE_CHECKING, Union +from typing import Optional, Type, TYPE_CHECKING, TypeVar from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager @@ -23,11 +23,11 @@ else: from . import pq from . import cursor from . import errors as e +from . import waiting from . 
import encodings from .pq import TransactionStatus, ExecStatus, Format from .sql import Composable from .proto import DumpersMap, LoadersMap, PQGen, PQGenConn, RV, Query, Params -from .waiting import wait, wait_async from .conninfo import make_conninfo from .generators import notifies from .transaction import Transaction, AsyncTransaction @@ -73,6 +73,8 @@ Notify.__module__ = "psycopg3" NoticeHandler = Callable[[e.Diagnostic], None] NotifyHandler = Callable[[Notify], None] +C = TypeVar("C", bound="BaseConnection") + class BaseConnection: """ @@ -164,6 +166,15 @@ class BaseConnection: def _set_client_encoding(self, name: str) -> None: raise NotImplementedError + def _set_client_encoding_gen(self, name: str) -> PQGen[None]: + self.pgconn.send_query_params( + b"select set_config('client_encoding', $1, false)", + [encodings.py2pg(name)], + ) + (result,) = yield from execute(self.pgconn) + if result.status != ExecStatus.TUPLES_OK: + raise e.error_from_result(result, encoding=self.client_encoding) + def cancel(self) -> None: """Cancel the current operation on the connection.""" c = self.pgconn.get_cancel() @@ -223,6 +234,100 @@ class BaseConnection: for cb in self._notify_handlers: cb(n) + # Generators to perform high-level operations on the connection + # + # These operations are expressed in terms of non-blocking generators + # and the task of waiting when needed (when the generators yield) is left + # to the connections subclass, which might wait either in blocking mode + # or through asyncio. + # + # All these generators assume exclusive acces to the connection: subclasses + # should have a lock and hold it before calling and consuming them. + + @classmethod + def _connect_gen( + cls: Type[C], + conninfo: str = "", + *, + autocommit: bool = False, + **kwargs: Any, + ) -> PQGenConn[C]: + """Generator to connect to the database and create a new instance.""" + conninfo = make_conninfo(conninfo, **kwargs) + pgconn = yield from connect(conninfo) + conn = cls(pgconn) + conn._autocommit = autocommit + return conn + + def _exec_command(self, command: Query) -> PQGen[None]: + """ + Generator to send a command and receive the result to the backend. + + Only used to implement internal commands such as commit, returning + no result. The cursor can do more complex stuff. 
+ """ + if self.pgconn.status != self.ConnStatus.OK: + if self.pgconn.status == self.ConnStatus.BAD: + raise e.OperationalError("the connection is closed") + raise e.InterfaceError( + f"cannot execute operations: the connection is" + f" in status {self.pgconn.status}" + ) + + if isinstance(command, str): + command = command.encode(self.client_encoding) + elif isinstance(command, Composable): + command = command.as_bytes(self) + + self.pgconn.send_query(command) + result = (yield from execute(self.pgconn))[-1] + if result.status != ExecStatus.COMMAND_OK: + if result.status == ExecStatus.FATAL_ERROR: + raise e.error_from_result( + result, encoding=self.client_encoding + ) + else: + raise e.InterfaceError( + f"unexpected result {ExecStatus(result.status).name}" + f" from command {command.decode('utf8')!r}" + ) + + def _start_query(self) -> PQGen[None]: + """Generator to start a transaction if necessary.""" + if self._autocommit: + return + + if self.pgconn.transaction_status != TransactionStatus.IDLE: + return + + yield from self._exec_command(b"begin") + + def _commit_gen(self) -> PQGen[None]: + """Generator implementing `Connection.commit()`.""" + if self._savepoints: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + if self.pgconn.transaction_status == TransactionStatus.IDLE: + return + + yield from self._exec_command(b"commit") + + def _rollback_gen(self) -> PQGen[None]: + """Generator implementing `Connection.rollback()`.""" + if self._savepoints: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. (Either raise Rollback() or allow " + "an exception to propagate out of the context.)" + ) + if self.pgconn.transaction_status == TransactionStatus.IDLE: + return + + yield from self._exec_command(b"rollback") + class Connection(BaseConnection): """ @@ -247,13 +352,9 @@ class Connection(BaseConnection): TODO: connection_timeout to be implemented. """ - - conninfo = make_conninfo(conninfo, **kwargs) - gen = connect(conninfo) - pgconn = cls.wait(gen) - conn = cls(pgconn) - conn._autocommit = autocommit - return conn + return cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit, **kwargs) + ) def __enter__(self) -> "Connection": return self @@ -284,16 +385,6 @@ class Connection(BaseConnection): return self.cursor_factory(self, format=format) - def _start_query(self) -> None: - # the function is meant to be called by a cursor once the lock is taken - if self._autocommit: - return - - if self.pgconn.transaction_status != TransactionStatus.IDLE: - return - - self._exec_command(b"begin") - def execute( self, query: Query, params: Optional[Params] = None ) -> "Cursor": @@ -304,49 +395,12 @@ class Connection(BaseConnection): def commit(self) -> None: """Commit any pending transaction to the database.""" with self.lock: - if self._savepoints: - raise e.ProgrammingError( - "Explicit commit() forbidden within a Transaction " - "context. (Transaction will be automatically committed " - "on successful exit from context.)" - ) - if self.pgconn.transaction_status == TransactionStatus.IDLE: - return - self._exec_command(b"commit") + self.wait(self._commit_gen()) def rollback(self) -> None: """Roll back to the start of any pending transaction.""" with self.lock: - if self._savepoints: - raise e.ProgrammingError( - "Explicit rollback() forbidden within a Transaction " - "context. 
(Either raise Rollback() or allow " - "an exception to propagate out of the context.)" - ) - if self.pgconn.transaction_status == TransactionStatus.IDLE: - return - self._exec_command(b"rollback") - - def _exec_command(self, command: Query) -> None: - # Caller must hold self.lock - - if isinstance(command, str): - command = command.encode(self.client_encoding) - elif isinstance(command, Composable): - command = command.as_string(self).encode(self.client_encoding) - - self.pgconn.send_query(command) - result = self.wait(execute(self.pgconn))[-1] - if result.status != ExecStatus.COMMAND_OK: - if result.status == ExecStatus.FATAL_ERROR: - raise e.error_from_result( - result, encoding=self.client_encoding - ) - else: - raise e.InterfaceError( - f"unexpected result {ExecStatus(result.status).name}" - f" from command {command.decode('utf8')!r}" - ) + self.wait(self._rollback_gen()) @contextmanager def transaction( @@ -365,27 +419,6 @@ class Connection(BaseConnection): with Transaction(self, savepoint_name, force_rollback) as tx: yield tx - @classmethod - def wait( - cls, - gen: Union[PQGen[RV], PQGenConn[RV]], - timeout: Optional[float] = 0.1, - ) -> RV: - return wait(gen, timeout=timeout) - - def _set_client_encoding(self, name: str) -> None: - with self.lock: - self.pgconn.send_query_params( - b"select set_config('client_encoding', $1, false)", - [encodings.py2pg(name)], - ) - gen = execute(self.pgconn) - (result,) = self.wait(gen) - if result.status != ExecStatus.TUPLES_OK: - raise e.error_from_result( - result, encoding=self.client_encoding - ) - def notifies(self) -> Iterator[Notify]: """ Yield `Notify` objects as soon as they are received from the database. @@ -400,10 +433,30 @@ class Connection(BaseConnection): ) yield n + def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + """ + Consume a generator operating on the connection. + + The function must be used on generators that don't change connection + fd (i.e. not on connect and reset). 
+ """ + return waiting.wait(gen, self.pgconn.socket, timeout=timeout) + + @classmethod + def _wait_conn( + cls, gen: PQGenConn[RV], timeout: Optional[float] = 0.1 + ) -> RV: + """Consume a connection generator.""" + return waiting.wait_conn(gen, timeout=timeout) + def _set_autocommit(self, value: bool) -> None: with self.lock: super()._set_autocommit(value) + def _set_client_encoding(self, name: str) -> None: + with self.lock: + self.wait(self._set_client_encoding_gen(name)) + class AsyncConnection(BaseConnection): """ @@ -423,12 +476,9 @@ class AsyncConnection(BaseConnection): async def connect( cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any ) -> "AsyncConnection": - conninfo = make_conninfo(conninfo, **kwargs) - gen = connect(conninfo) - pgconn = await cls.wait(gen) - conn = cls(pgconn) - conn._autocommit = autocommit - return conn + return await cls._wait_conn( + cls._connect_gen(conninfo, autocommit=autocommit, **kwargs) + ) async def __aenter__(self) -> "AsyncConnection": return self @@ -460,16 +510,6 @@ class AsyncConnection(BaseConnection): return self.cursor_factory(self, format=format) - async def _start_query(self) -> None: - # the function is meant to be called by a cursor once the lock is taken - if self._autocommit: - return - - if self.pgconn.transaction_status != TransactionStatus.IDLE: - return - - await self._exec_command(b"begin") - async def execute( self, query: Query, params: Optional[Params] = None ) -> "AsyncCursor": @@ -478,48 +518,11 @@ class AsyncConnection(BaseConnection): async def commit(self) -> None: async with self.lock: - if self._savepoints: - raise e.ProgrammingError( - "Explicit commit() forbidden within a Transaction " - "context. (Transaction will be automatically committed " - "on successful exit from context.)" - ) - if self.pgconn.transaction_status == TransactionStatus.IDLE: - return - await self._exec_command(b"commit") + await self.wait(self._commit_gen()) async def rollback(self) -> None: async with self.lock: - if self._savepoints: - raise e.ProgrammingError( - "Explicit rollback() forbidden within a Transaction " - "context. 
(Either raise Rollback() or allow " - "an exception to propagate out of the context.)" - ) - if self.pgconn.transaction_status == TransactionStatus.IDLE: - return - await self._exec_command(b"rollback") - - async def _exec_command(self, command: Query) -> None: - # Caller must hold self.lock - - if isinstance(command, str): - command = command.encode(self.client_encoding) - elif isinstance(command, Composable): - command = command.as_string(self).encode(self.client_encoding) - - self.pgconn.send_query(command) - result = (await self.wait(execute(self.pgconn)))[-1] - if result.status != ExecStatus.COMMAND_OK: - if result.status == ExecStatus.FATAL_ERROR: - raise e.error_from_result( - result, encoding=self.client_encoding - ) - else: - raise e.InterfaceError( - f"unexpected result {ExecStatus(result.status).name}" - f" from command {command.decode('utf8')!r}" - ) + await self.wait(self._rollback_gen()) @asynccontextmanager async def transaction( @@ -534,9 +537,23 @@ class AsyncConnection(BaseConnection): async with tx: yield tx + async def notifies(self) -> AsyncIterator[Notify]: + while 1: + async with self.lock: + ns = await self.wait(notifies(self.pgconn)) + enc = self.client_encoding + for pgn in ns: + n = Notify( + pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid + ) + yield n + + async def wait(self, gen: PQGen[RV]) -> RV: + return await waiting.wait_async(gen, self.pgconn.socket) + @classmethod - async def wait(cls, gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV: - return await wait_async(gen) + async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV: + return await waiting.wait_async_conn(gen) def _set_client_encoding(self, name: str) -> None: raise AttributeError( @@ -547,27 +564,7 @@ class AsyncConnection(BaseConnection): async def set_client_encoding(self, name: str) -> None: """Async version of the `~Connection.client_encoding` setter.""" async with self.lock: - self.pgconn.send_query_params( - b"select set_config('client_encoding', $1, false)", - [encodings.py2pg(name)], - ) - gen = execute(self.pgconn) - (result,) = await self.wait(gen) - if result.status != ExecStatus.TUPLES_OK: - raise e.error_from_result( - result, encoding=self.client_encoding - ) - - async def notifies(self) -> AsyncIterator[Notify]: - while 1: - async with self.lock: - ns = await self.wait(notifies(self.pgconn)) - enc = self.client_encoding - for pgn in ns: - n = Notify( - pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid - ) - yield n + await self.wait(self._set_client_encoding_gen(name)) def _set_autocommit(self, value: bool) -> None: raise AttributeError( diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 1f4b854b2..af161b4e7 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -12,7 +12,7 @@ from contextlib import contextmanager from . import errors as e from . import pq -from .pq import ConnStatus, ExecStatus, Format +from .pq import ExecStatus, Format from .copy import Copy, AsyncCopy from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen from ._column import Column @@ -148,29 +148,76 @@ class BaseCursor(Generic[ConnectionType]): else: return None - def _start_query(self) -> None: + # + # Generators for the high level operations on the cursor + # + # Like for sync/async connections, these are implemented as generators + # so that different concurrency strategies (threads,asyncio) can use their + # own way of waiting (or better, `connection.wait()`). 
+ # + + def _execute_gen( + self, query: Query, params: Optional[Params] = None + ) -> PQGen[None]: + """Generator implementing `Cursor.execute()`.""" + yield from self._start_query() + self._execute_send(query, params) + results = yield from execute(self._conn.pgconn) + self._execute_results(results) + + def _executemany_gen( + self, query: Query, params_seq: Sequence[Params] + ) -> PQGen[None]: + """Generator implementing `Cursor.executemany()`.""" + yield from self._start_query() + first = True + for params in params_seq: + if first: + pgq = self._send_prepare(b"", query, params) + (result,) = yield from execute(self._conn.pgconn) + if result.status == ExecStatus.FATAL_ERROR: + raise e.error_from_result( + result, encoding=self._conn.client_encoding + ) + else: + pgq.dump(params) + + self._send_query_prepared(b"", pgq) + (result,) = yield from execute(self._conn.pgconn) + self._execute_results((result,)) + + def _start_query(self) -> PQGen[None]: + """Generator to start the processing of a query. + + It is implemented as generator because it may send additional queries, + such as `begin`. + """ from . import adapt if self.closed: raise e.InterfaceError("the cursor is closed") - if self._conn.closed: - raise e.InterfaceError("the connection is closed") - - if self._conn.pgconn.status != ConnStatus.OK: - raise e.InterfaceError( - f"cannot execute operations: the connection is" - f" in status {self._conn.pgconn.status}" - ) - self._reset() self._transformer = adapt.Transformer(self) + yield from self._conn._start_query() + + def _start_copy_gen(self, statement: Query) -> PQGen[None]: + """Generator implementing sending a command for `Cursor.copy().""" + yield from self._start_query() + # Make sure to avoid PQexec to avoid receiving a mix of COPY and + # other operations. + self._execute_send(statement, None, no_pqexec=True) + (result,) = yield from execute(self._conn.pgconn) + self._check_copy_result(result) + self.pgresult = result # will set it on the transformer too def _execute_send( self, query: Query, params: Optional[Params], no_pqexec: bool = False ) -> None: """ - Implement part of execute() before waiting common to sync and async + Implement part of execute() before waiting common to sync and async. + + This is not a generator, but a normal, non-blocking function. """ pgq = PostgresQuery(self._transformer) pgq.convert(query, params) @@ -206,6 +253,8 @@ class BaseCursor(Generic[ConnectionType]): def _execute_results(self, results: Sequence["PGresult"]) -> None: """ Implement part of execute() after waiting common to sync and async + + This is not a generator, but a normal, non-blocking function. """ if not results: raise e.InternalError("got no result from the query") @@ -241,9 +290,6 @@ class BaseCursor(Generic[ConnectionType]): def _send_prepare( self, name: bytes, query: Query, params: Optional[Params] ) -> PostgresQuery: - """ - Implement part of execute() before waiting common to sync and async - """ pgq = PostgresQuery(self._transformer) pgq.convert(query, params) @@ -270,16 +316,10 @@ class BaseCursor(Generic[ConnectionType]): "the last operation didn't produce a result" ) - def _check_copy_results(self, results: Sequence["PGresult"]) -> None: + def _check_copy_result(self, result: "PGresult") -> None: """ Check that the value returned in a copy() operation is a legit COPY. 
""" - if len(results) != 1: - raise e.InternalError( - f"expected 1 result from copy, got {len(results)} instead" - ) - - result = results[0] status = result.status if status in (ExecStatus.COPY_IN, ExecStatus.COPY_OUT): return @@ -322,12 +362,7 @@ class Cursor(BaseCursor["Connection"]): Execute a query or command to the database. """ with self._conn.lock: - self._start_query() - self._conn._start_query() - self._execute_send(query, params) - gen = execute(self._conn.pgconn) - results = self._conn.wait(gen) - self._execute_results(results) + self._conn.wait(self._execute_gen(query, params)) return self def executemany(self, query: Query, params_seq: Sequence[Params]) -> None: @@ -335,25 +370,7 @@ class Cursor(BaseCursor["Connection"]): Execute the same command with a sequence of input data. """ with self._conn.lock: - self._start_query() - self._conn._start_query() - first = True - for params in params_seq: - if first: - pgq = self._send_prepare(b"", query, params) - gen = execute(self._conn.pgconn) - (result,) = self._conn.wait(gen) - if result.status == ExecStatus.FATAL_ERROR: - raise e.error_from_result( - result, encoding=self._conn.client_encoding - ) - else: - pgq.dump(params) - - self._send_query_prepared(b"", pgq) - gen = execute(self._conn.pgconn) - (result,) = self._conn.wait(gen) - self._execute_results((result,)) + self._conn.wait(self._executemany_gen(query, params_seq)) def fetchone(self) -> Optional[Sequence[Any]]: """ @@ -414,22 +431,11 @@ class Cursor(BaseCursor["Connection"]): """ Initiate a :sql:`COPY` operation and return an object to manage it. """ - with self._start_copy(statement) as copy: - yield copy - - def _start_copy(self, statement: Query) -> Copy: with self._conn.lock: - self._start_query() - self._conn._start_query() - # Make sure to avoid PQexec to avoid receiving a mix of COPY and - # other operations. 
- self._execute_send(statement, None, no_pqexec=True) - gen = execute(self._conn.pgconn) - results = self._conn.wait(gen) - self._check_copy_results(results) - self.pgresult = results[0] # will set it on the transformer too + self._conn.wait(self._start_copy_gen(statement)) - return Copy(self) + with Copy(self) as copy: + yield copy class AsyncCursor(BaseCursor["AsyncConnection"]): @@ -454,37 +460,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self, query: Query, params: Optional[Params] = None ) -> "AsyncCursor": async with self._conn.lock: - self._start_query() - await self._conn._start_query() - self._execute_send(query, params) - gen = execute(self._conn.pgconn) - results = await self._conn.wait(gen) - self._execute_results(results) + await self._conn.wait(self._execute_gen(query, params)) return self async def executemany( self, query: Query, params_seq: Sequence[Params] ) -> None: async with self._conn.lock: - self._start_query() - await self._conn._start_query() - first = True - for params in params_seq: - if first: - pgq = self._send_prepare(b"", query, params) - gen = execute(self._conn.pgconn) - (result,) = await self._conn.wait(gen) - if result.status == ExecStatus.FATAL_ERROR: - raise e.error_from_result( - result, encoding=self._conn.client_encoding - ) - else: - pgq.dump(params) - - self._send_query_prepared(b"", pgq) - gen = execute(self._conn.pgconn) - (result,) = await self._conn.wait(gen) - self._execute_results((result,)) + await self._conn.wait(self._executemany_gen(query, params_seq)) async def fetchone(self) -> Optional[Sequence[Any]]: self._check_result() @@ -533,23 +516,11 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): @asynccontextmanager async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: - copy = await self._start_copy(statement) - async with copy: - yield copy - - async def _start_copy(self, statement: Query) -> AsyncCopy: async with self._conn.lock: - self._start_query() - await self._conn._start_query() - # Make sure to avoid PQexec to avoid receiving a mix of COPY and - # other operations. - self._execute_send(statement, None, no_pqexec=True) - gen = execute(self._conn.pgconn) - results = await self._conn.wait(gen) - self._check_copy_results(results) - self.pgresult = results[0] # will set it on the transformer too - - return AsyncCopy(self) + await self._conn.wait(self._start_copy_gen(statement)) + + async with AsyncCopy(self) as copy: + yield copy class NamedCursorMixin: diff --git a/psycopg3/psycopg3/generators.py b/psycopg3/psycopg3/generators.py index 50a6c4e5b..cbf5d6dd6 100644 --- a/psycopg3/psycopg3/generators.py +++ b/psycopg3/psycopg3/generators.py @@ -88,7 +88,6 @@ def send(pgconn: PGconn) -> PQGen[None]: After this generator has finished you may want to cycle using `_fetch()` to retrieve the results available. 
""" - yield pgconn.socket while 1: f = pgconn.flush() if f == 0: @@ -150,7 +149,6 @@ _copy_statuses = ( def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]: - yield pgconn.socket yield Wait.R pgconn.consume_input() @@ -166,7 +164,6 @@ def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]: def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]: - yield pgconn.socket while 1: nbytes, data = pgconn.get_copy_data(1) if nbytes != 0: @@ -192,14 +189,12 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]: def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]: - yield pgconn.socket # Retry enqueuing data until successful while pgconn.put_copy_data(buffer) == 0: yield Wait.W def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]: - yield pgconn.socket # Retry enqueuing end copy message until successful while pgconn.put_copy_end(error) == 0: yield Wait.W diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index 3deb72dc2..fff6ffee0 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -33,7 +33,7 @@ PQGenConn = Generator[Tuple[int, "Wait"], "Ready", RV] This can happen in connection and reset, but not in normal querying. """ -PQGen = Generator[Union[int, "Wait"], "Ready", RV] +PQGen = Generator["Wait", "Ready", RV] """Generator for processes where the connection file number won't change. The first item generated is the file descriptor; following items are be the diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 21c28fda0..6be6e34f4 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -7,11 +7,11 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING +from typing import Generic, Optional, Type, Union, TYPE_CHECKING from . 
import sql from .pq import TransactionStatus -from .proto import ConnectionType +from .proto import ConnectionType, PQGen if TYPE_CHECKING: from .connection import Connection, AsyncConnection # noqa: F401 @@ -74,7 +74,7 @@ class BaseTransaction(Generic[ConnectionType]): args.append("force_rollback=True") return f"{self.__class__.__qualname__}({', '.join(args)})" - def _enter_commands(self) -> List[bytes]: + def _enter_gen(self) -> PQGen[None]: if not self._yolo: raise TypeError("transaction blocks can be used only once") else: @@ -107,9 +107,21 @@ class BaseTransaction(Generic[ConnectionType]): ) self._conn._savepoints.append(self._savepoint_name) - return commands + return self._conn._exec_command(b"; ".join(commands)) - def _commit_commands(self) -> List[bytes]: + def _exit_gen( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> PQGen[bool]: + if not exc_val and not self.force_rollback: + yield from self._commit_gen() + return False + else: + return (yield from self._rollback_gen(exc_val)) + + def _commit_gen(self) -> PQGen[None]: assert self._conn._savepoints[-1] == self._savepoint_name self._conn._savepoints.pop() @@ -125,9 +137,14 @@ class BaseTransaction(Generic[ConnectionType]): assert not self._conn._savepoints commands.append(b"commit") - return commands + return self._conn._exec_command(b"; ".join(commands)) + + def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: + if isinstance(exc_val, Rollback): + _log.debug( + f"{self._conn}: Explicit rollback from: ", exc_info=True + ) - def _rollback_commands(self) -> List[bytes]: assert self._conn._savepoints[-1] == self._savepoint_name self._conn._savepoints.pop() @@ -143,7 +160,13 @@ class BaseTransaction(Generic[ConnectionType]): assert not self._conn._savepoints commands.append(b"rollback") - return commands + yield from self._conn._exec_command(b"; ".join(commands)) + + if isinstance(exc_val, Rollback): + if not exc_val.transaction or exc_val.transaction is self: + return True # Swallow the exception + + return False class Transaction(BaseTransaction["Connection"]): @@ -155,7 +178,7 @@ class Transaction(BaseTransaction["Connection"]): def __enter__(self) -> "Transaction": with self._conn.lock: - self._execute(self._enter_commands()) + self._conn.wait(self._enter_gen()) return self def __exit__( @@ -165,33 +188,7 @@ class Transaction(BaseTransaction["Connection"]): exc_tb: Optional[TracebackType], ) -> bool: with self._conn.lock: - if not exc_val and not self.force_rollback: - self._commit() - return False - else: - return self._rollback(exc_val) - - def _commit(self) -> None: - """Commit changes made in the transaction context.""" - self._execute(self._commit_commands()) - - def _rollback(self, exc_val: Optional[BaseException]) -> bool: - # Rollback changes made in the transaction context - if isinstance(exc_val, Rollback): - _log.debug( - f"{self._conn}: Explicit rollback from: ", exc_info=True - ) - - self._execute(self._rollback_commands()) - - if isinstance(exc_val, Rollback): - if exc_val.transaction in (self, None): - return True # Swallow the exception - - return False - - def _execute(self, commands: List[bytes]) -> None: - self._conn._exec_command(b"; ".join(commands)) + return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) class AsyncTransaction(BaseTransaction["AsyncConnection"]): @@ -203,8 +200,7 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): async def __aenter__(self) -> "AsyncTransaction": async 
with self._conn.lock: - await self._execute(self._enter_commands()) - + await self._conn.wait(self._enter_gen()) return self async def __aexit__( @@ -214,30 +210,6 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): exc_tb: Optional[TracebackType], ) -> bool: async with self._conn.lock: - if not exc_val and not self.force_rollback: - await self._commit() - return False - else: - return await self._rollback(exc_val) - - async def _commit(self) -> None: - """Commit changes made in the transaction context.""" - await self._execute(self._commit_commands()) - - async def _rollback(self, exc_val: Optional[BaseException]) -> bool: - # Rollback changes made in the transaction context - if isinstance(exc_val, Rollback): - _log.debug( - f"{self._conn}: Explicit rollback from: ", exc_info=True + return await self._conn.wait( + self._exit_gen(exc_type, exc_val, exc_tb) ) - - await self._execute(self._rollback_commands()) - - if isinstance(exc_val, Rollback): - if exc_val.transaction in (self, None): - return True # Swallow the exception - - return False - - async def _execute(self, commands: List[bytes]) -> None: - await self._conn._exec_command(b"; ".join(commands)) diff --git a/psycopg3/psycopg3/waiting.py b/psycopg3/psycopg3/waiting.py index 85e111966..af818fbb8 100644 --- a/psycopg3/psycopg3/waiting.py +++ b/psycopg3/psycopg3/waiting.py @@ -10,7 +10,7 @@ These functions are designed to consume the generators returned by the from enum import IntEnum -from typing import Optional, Union +from typing import Optional from asyncio import get_event_loop, Event from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE @@ -29,11 +29,40 @@ class Ready(IntEnum): W = EVENT_WRITE -def wait( - gen: Union[PQGen[RV], PQGenConn[RV]], timeout: Optional[float] = None -) -> RV: +def wait(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV: """ - Wait for a generator using the best option available on the platform. + Wait for a generator using the best strategy available. + + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. + :type timeout: float + :return: whatever *gen* returns on completion. + + Consume *gen*, scheduling `fileno` for completion when it is reported to + block. Once ready again send the ready state back to *gen*. + """ + sel = DefaultSelector() + try: + s = next(gen) + while 1: + sel.register(fileno, s) + ready = None + while not ready: + ready = sel.select(timeout=timeout) + sel.unregister(fileno) + s = gen.send(ready[0][1]) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: + """ + Wait for a connection generator using the best strategy available. :param gen: a generator performing database operations and yielding (fd, `Ready`) pairs when it would block. @@ -41,59 +70,94 @@ def wait( to allow Ctrl-C. :type timeout: float :return: whatever *gen* returns on completion. + + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. """ - fd: int - s: Wait sel = DefaultSelector() try: - # Use the first generated item to tell if it's a PQgen or PQgenConn. - # Note: mypy gets confused by the behaviour of this generator. 
- item = next(gen) - if isinstance(item, tuple): - fd, s = item - while 1: - sel.register(fd, s) - ready = None - while not ready: - ready = sel.select(timeout=timeout) - sel.unregister(fd) - - assert len(ready) == 1 - fd, s = gen.send(ready[0][1]) - else: - fd = item # type: ignore[assignment] - s = next(gen) # type: ignore[assignment] - while 1: - sel.register(fd, s) - ready = None - while not ready: - ready = sel.select(timeout=timeout) - sel.unregister(fd) - - assert len(ready) == 1 - s = gen.send(ready[0][1]) # type: ignore[arg-type,assignment] + fileno, s = next(gen) + while 1: + sel.register(fileno, s) + ready = None + while not ready: + ready = sel.select(timeout=timeout) + sel.unregister(fileno) + fileno, s = gen.send(ready[0][1]) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None return rv -async def wait_async(gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV: +async def wait_async(gen: PQGen[RV], fileno: int) -> RV: """ Coroutine waiting for a generator to complete. - *gen* is expected to generate tuples (fd, status). consume it and block - according to the status until fd is ready. Send back the ready state - to the generator. + :param gen: a generator performing database operations and yielding + `Ready` values when it would block. + :param fileno: the file descriptor to wait on. + :return: whatever *gen* returns on completion. + + Behave like in `wait()`, but exposing an `asyncio` interface. + """ + # Use an event to block and restart after the fd state changes. + # Not sure this is the best implementation but it's a start. + ev = Event() + loop = get_event_loop() + ready: Ready + s: Wait + + def wakeup(state: Ready) -> None: + nonlocal ready + ready = state + ev.set() + + try: + s = next(gen) + while 1: + ev.clear() + if s == Wait.R: + loop.add_reader(fileno, wakeup, Ready.R) + await ev.wait() + loop.remove_reader(fileno) + elif s == Wait.W: + loop.add_writer(fileno, wakeup, Ready.W) + await ev.wait() + loop.remove_writer(fileno) + elif s == Wait.RW: + loop.add_reader(fileno, wakeup, Ready.R) + loop.add_writer(fileno, wakeup, Ready.W) + await ev.wait() + loop.remove_reader(fileno) + loop.remove_writer(fileno) + else: + raise e.InternalError("bad poll status: %s") + s = gen.send(ready) + + except StopIteration as ex: + rv: RV = ex.args[0] if ex.args else None + return rv + + +async def wait_async_conn(gen: PQGenConn[RV]) -> RV: + """ + Coroutine waiting for a connection generator to complete. + + :param gen: a generator performing database operations and yielding + (fd, `Ready`) pairs when it would block. + :param timeout: timeout (in seconds) to check for other interrupt, e.g. + to allow Ctrl-C. + :return: whatever *gen* returns on completion. - Return what the generator eventually returned. + Behave like in `wait()`, but take the fileno to wait from the generator + itself, which might change during processing. """ # Use an event to block and restart after the fd state changes. # Not sure this is the best implementation but it's a start. ev = Event() loop = get_event_loop() ready: Ready - fd: int s: Wait def wakeup(state: Ready) -> None: @@ -102,52 +166,26 @@ async def wait_async(gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV: ev.set() try: - # Use the first generated item to tell if it's a PQgen or PQgenConn. - # Note: mypy gets confused by the behaviour of this generator. 
- item = next(gen) - if isinstance(item, tuple): - fd, s = item - while 1: - ev.clear() - if s == Wait.R: - loop.add_reader(fd, wakeup, Ready.R) - await ev.wait() - loop.remove_reader(fd) - elif s == Wait.W: - loop.add_writer(fd, wakeup, Ready.W) - await ev.wait() - loop.remove_writer(fd) - elif s == Wait.RW: - loop.add_reader(fd, wakeup, Ready.R) - loop.add_writer(fd, wakeup, Ready.W) - await ev.wait() - loop.remove_reader(fd) - loop.remove_writer(fd) - else: - raise e.InternalError("bad poll status: %s") - fd, s = gen.send(ready) # type: ignore[misc] - else: - fd = item # type: ignore[assignment] - s = next(gen) # type: ignore[assignment] - while 1: - ev.clear() - if s == Wait.R: - loop.add_reader(fd, wakeup, Ready.R) - await ev.wait() - loop.remove_reader(fd) - elif s == Wait.W: - loop.add_writer(fd, wakeup, Ready.W) - await ev.wait() - loop.remove_writer(fd) - elif s == Wait.RW: - loop.add_reader(fd, wakeup, Ready.R) - loop.add_writer(fd, wakeup, Ready.W) - await ev.wait() - loop.remove_reader(fd) - loop.remove_writer(fd) - else: - raise e.InternalError("bad poll status: %s") - s = gen.send(ready) # type: ignore[arg-type,assignment] + fileno, s = next(gen) + while 1: + ev.clear() + if s == Wait.R: + loop.add_reader(fileno, wakeup, Ready.R) + await ev.wait() + loop.remove_reader(fileno) + elif s == Wait.W: + loop.add_writer(fileno, wakeup, Ready.W) + await ev.wait() + loop.remove_writer(fileno) + elif s == Wait.RW: + loop.add_reader(fileno, wakeup, Ready.R) + loop.add_writer(fileno, wakeup, Ready.W) + await ev.wait() + loop.remove_reader(fileno) + loop.remove_writer(fileno) + else: + raise e.InternalError("bad poll status: %s") + fileno, s = gen.send(ready) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None diff --git a/psycopg3_c/psycopg3_c/generators.pyx b/psycopg3_c/psycopg3_c/generators.pyx index 16eb8267b..baf0119e7 100644 --- a/psycopg3_c/psycopg3_c/generators.pyx +++ b/psycopg3_c/psycopg3_c/generators.pyx @@ -73,10 +73,6 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]: cdef libpq.PGresult *pgres cdef int cires, ibres - # Start the generator by sending the connection fd, which won't change - # during the query process. 
- yield libpq.PQsocket(pgconn_ptr) - # Sending the query while 1: if libpq.PQflush(pgconn_ptr) == 0: diff --git a/tests/pq/test_async.py b/tests/pq/test_async.py index 9fddde267..820f6e85d 100644 --- a/tests/pq/test_async.py +++ b/tests/pq/test_async.py @@ -5,6 +5,10 @@ from psycopg3 import pq from psycopg3.generators import execute +def execute_wait(pgconn): + return psycopg3.waiting.wait(execute(pgconn), pgconn.socket) + + def test_send_query(pgconn): # This test shows how to process an async query in all its glory pgconn.nonblocking = 1 @@ -63,7 +67,7 @@ def test_send_query_compact_test(pgconn): b"/* %s */ select pg_sleep(0.01); select 1 as foo;" % (b"x" * 1_000_000) ) - results = psycopg3.waiting.wait(execute(pgconn)) + results = execute_wait(pgconn) assert len(results) == 2 assert results[0].nfields == 1 @@ -80,7 +84,7 @@ def test_send_query_compact_test(pgconn): def test_send_query_params(pgconn): pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"]) - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.TUPLES_OK assert res.get_value(0, 0) == b"8" @@ -91,11 +95,11 @@ def test_send_query_params(pgconn): def test_send_prepare(pgconn): pgconn.send_prepare(b"prep", b"select $1::int + $2::int") - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message pgconn.send_query_prepared(b"prep", [b"3", b"5"]) - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.get_value(0, 0) == b"8" pgconn.finish() @@ -107,22 +111,22 @@ def test_send_prepare(pgconn): def test_send_prepare_types(pgconn): pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23]) - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message pgconn.send_query_prepared(b"prep", [b"3", b"5"]) - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.get_value(0, 0) == b"8" def test_send_prepared_binary_in(pgconn): val = b"foo\00bar" pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)") - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1]) - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.TUPLES_OK assert res.get_value(0, 0) == b"3" assert res.get_value(0, 1) == b"7" @@ -137,12 +141,12 @@ def test_send_prepared_binary_in(pgconn): def test_send_prepared_binary_out(pgconn, fmt, out): val = b"foo\00bar" pgconn.send_prepare(b"", b"select $1::bytea") - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message pgconn.send_query_prepared( b"", [val], param_formats=[1], result_format=fmt ) - (res,) = psycopg3.waiting.wait(execute(pgconn)) + (res,) = execute_wait(pgconn) assert res.status == pq.ExecStatus.TUPLES_OK assert res.get_value(0, 0) == out diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py index 17c28877c..01e2e2362 100644 --- a/tests/pq/test_pgconn.py +++ b/tests/pq/test_pgconn.py @@ -242,7 +242,7 @@ def test_transaction_status(pgconn): assert pgconn.transaction_status == pq.TransactionStatus.INTRANS pgconn.send_query(b"select 1") assert pgconn.transaction_status == 
pq.TransactionStatus.ACTIVE - psycopg3.waiting.wait(psycopg3.generators.execute(pgconn)) + psycopg3.waiting.wait(psycopg3.generators.execute(pgconn), pgconn.socket) assert pgconn.transaction_status == pq.TransactionStatus.INTRANS pgconn.finish() assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN diff --git a/tests/test_connection.py b/tests/test_connection.py index af938857e..0be39e0e9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -68,7 +68,7 @@ def test_close(conn): assert conn.closed assert conn.pgconn.status == conn.ConnStatus.BAD - with pytest.raises(psycopg3.InterfaceError): + with pytest.raises(psycopg3.OperationalError): cur.execute("select 1") diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index de1879e9c..2d57d9c0f 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -73,7 +73,7 @@ async def test_close(aconn): assert aconn.closed assert aconn.pgconn.status == aconn.ConnStatus.BAD - with pytest.raises(psycopg3.InterfaceError): + with pytest.raises(psycopg3.OperationalError): await cur.execute("select 1") diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 1caa2c289..5458216cd 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -54,7 +54,12 @@ class ListPopAll(list): @pytest.fixture def commands(conn, monkeypatch): - """The queue of commands issued internally by the test connection.""" + """The list of commands issued internally by the test connection.""" + yield patch_exec(conn, monkeypatch) + + +def patch_exec(conn, monkeypatch): + """Helper to implement the commands fixture both sync and async.""" _orig_exec_command = conn._exec_command L = ListPopAll() @@ -63,10 +68,10 @@ def commands(conn, monkeypatch): command = command.decode(conn.client_encoding) L.insert(0, command) - _orig_exec_command(command) + return _orig_exec_command(command) monkeypatch.setattr(conn, "_exec_command", _exec_command) - yield L + return L def in_transaction(conn): diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index f9cbd420d..f3d1582e9 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -3,27 +3,16 @@ import pytest from psycopg3 import ProgrammingError, Rollback from .test_transaction import in_transaction, insert_row, inserted -from .test_transaction import ExpectedException, ListPopAll +from .test_transaction import ExpectedException, patch_exec from .test_transaction import create_test_table # noqa # autouse fixture pytestmark = pytest.mark.asyncio @pytest.fixture -async def commands(aconn, monkeypatch): - """The queue of commands issued internally by the test connection.""" - _orig_exec_command = aconn._exec_command - L = ListPopAll() - - async def _exec_command(command): - if isinstance(command, bytes): - command = command.decode(aconn.client_encoding) - - L.insert(0, command) - await _orig_exec_command(command) - - monkeypatch.setattr(aconn, "_exec_command", _exec_command) - yield L +def commands(aconn, monkeypatch): + """The list of commands issued internally by the test connection.""" + yield patch_exec(aconn, monkeypatch) async def test_basic(aconn): -- 2.47.2
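
For reference, the protocol this patch standardises on can be shown with a small, self-contained sketch. It is illustrative only: `recv_line`, `Wait`, `Ready` and the socketpair demo are hypothetical names local to the example, not psycopg3 API. A PQGen-style generator yields `Wait` values whenever it would block and never yields the file descriptor; the consumer (like `waiting.wait()` above) receives the fd as a separate argument and sends the ready state back into the generator:

    import socket
    from enum import IntEnum
    from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE

    class Wait(IntEnum):
        R = EVENT_READ
        W = EVENT_WRITE

    class Ready(IntEnum):
        R = EVENT_READ
        W = EVENT_WRITE

    def recv_line(sock):
        """PQGen-style generator: yield Wait values, never the fd itself."""
        buf = b""
        while not buf.endswith(b"\n"):
            yield Wait.R  # would block: ask the consumer to wait for readability
            buf += sock.recv(4096)
        return buf  # delivered to the consumer via StopIteration

    def wait(gen, fileno, timeout=None):
        """Consume a PQGen-style generator against a fixed file descriptor."""
        sel = DefaultSelector()
        try:
            s = next(gen)
            while True:
                sel.register(fileno, s)
                ready = None
                while not ready:
                    ready = sel.select(timeout=timeout)
                sel.unregister(fileno)
                s = gen.send(Ready(ready[0][1]))
        except StopIteration as ex:
            return ex.args[0] if ex.args else None

    if __name__ == "__main__":
        a, b = socket.socketpair()
        b.sendall(b"hello\n")
        a.setblocking(False)
        print(wait(recv_line(a), a.fileno()))  # b'hello\n'

Because the generator never mentions the fd or the waiting strategy, the same generator can be consumed by a selector loop like the one above or by an asyncio-based loop such as wait_async(), which is what removes the sync/async duplication in connection and cursor.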