git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Implement connection/cursor procedures as generators
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 22 Dec 2020 17:09:34 +0000 (18:09 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 24 Dec 2020 03:51:34 +0000 (04:51 +0100)
Generators are common to both sync and async code, so there is much
less code duplication; this will be useful in light of more complex
features such as prepared transactions.

As a side effect of the refactoring, operations on closed connections
now raise OperationalError rather than InterfaceError, because
connection closure may depend on external factors.

Dropped the fd from the PQGen stream: having the fd in the middle of
the `Ready` values made it more difficult to chain more than one
operation in the same generator. The fd is now passed as an extra
parameter to the waiting functions.

13 files changed:
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/generators.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/transaction.py
psycopg3/psycopg3/waiting.py
psycopg3_c/psycopg3_c/generators.pyx
tests/pq/test_async.py
tests/pq/test_pgconn.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_transaction.py
tests/test_transaction_async.py

index d801a22384152da2c475993a6a08267b71c2e139..9af0bc6914ba89e3ba5e4a1345233e74cb333f5a 100644 (file)
@@ -10,7 +10,7 @@ import logging
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
-from typing import Optional, Type, TYPE_CHECKING, Union
+from typing import Optional, Type, TYPE_CHECKING, TypeVar
 from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
@@ -23,11 +23,11 @@ else:
 from . import pq
 from . import cursor
 from . import errors as e
+from . import waiting
 from . import encodings
 from .pq import TransactionStatus, ExecStatus, Format
 from .sql import Composable
 from .proto import DumpersMap, LoadersMap, PQGen, PQGenConn, RV, Query, Params
-from .waiting import wait, wait_async
 from .conninfo import make_conninfo
 from .generators import notifies
 from .transaction import Transaction, AsyncTransaction
@@ -73,6 +73,8 @@ Notify.__module__ = "psycopg3"
 NoticeHandler = Callable[[e.Diagnostic], None]
 NotifyHandler = Callable[[Notify], None]
 
+C = TypeVar("C", bound="BaseConnection")
+
 
 class BaseConnection:
     """
@@ -164,6 +166,15 @@ class BaseConnection:
     def _set_client_encoding(self, name: str) -> None:
         raise NotImplementedError
 
+    def _set_client_encoding_gen(self, name: str) -> PQGen[None]:
+        self.pgconn.send_query_params(
+            b"select set_config('client_encoding', $1, false)",
+            [encodings.py2pg(name)],
+        )
+        (result,) = yield from execute(self.pgconn)
+        if result.status != ExecStatus.TUPLES_OK:
+            raise e.error_from_result(result, encoding=self.client_encoding)
+
     def cancel(self) -> None:
         """Cancel the current operation on the connection."""
         c = self.pgconn.get_cancel()
@@ -223,6 +234,100 @@ class BaseConnection:
         for cb in self._notify_handlers:
             cb(n)
 
+    # Generators to perform high-level operations on the connection
+    #
+    # These operations are expressed in terms of non-blocking generators
+    # and the task of waiting when needed (when the generators yield) is left
+    # to the connection subclasses, which might wait either in blocking mode
+    # or through asyncio.
+    #
+    # All these generators assume exclusive access to the connection: subclasses
+    # should have a lock and hold it before calling and consuming them.
+
+    @classmethod
+    def _connect_gen(
+        cls: Type[C],
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        **kwargs: Any,
+    ) -> PQGenConn[C]:
+        """Generator to connect to the database and create a new instance."""
+        conninfo = make_conninfo(conninfo, **kwargs)
+        pgconn = yield from connect(conninfo)
+        conn = cls(pgconn)
+        conn._autocommit = autocommit
+        return conn
+
+    def _exec_command(self, command: Query) -> PQGen[None]:
+        """
+        Generator to send a command to the backend and receive the result.
+
+        Only used to implement internal commands such as commit, returning
+        no result. The cursor can do more complex stuff.
+        """
+        if self.pgconn.status != self.ConnStatus.OK:
+            if self.pgconn.status == self.ConnStatus.BAD:
+                raise e.OperationalError("the connection is closed")
+            raise e.InterfaceError(
+                f"cannot execute operations: the connection is"
+                f" in status {self.pgconn.status}"
+            )
+
+        if isinstance(command, str):
+            command = command.encode(self.client_encoding)
+        elif isinstance(command, Composable):
+            command = command.as_bytes(self)
+
+        self.pgconn.send_query(command)
+        result = (yield from execute(self.pgconn))[-1]
+        if result.status != ExecStatus.COMMAND_OK:
+            if result.status == ExecStatus.FATAL_ERROR:
+                raise e.error_from_result(
+                    result, encoding=self.client_encoding
+                )
+            else:
+                raise e.InterfaceError(
+                    f"unexpected result {ExecStatus(result.status).name}"
+                    f" from command {command.decode('utf8')!r}"
+                )
+
+    def _start_query(self) -> PQGen[None]:
+        """Generator to start a transaction if necessary."""
+        if self._autocommit:
+            return
+
+        if self.pgconn.transaction_status != TransactionStatus.IDLE:
+            return
+
+        yield from self._exec_command(b"begin")
+
+    def _commit_gen(self) -> PQGen[None]:
+        """Generator implementing `Connection.commit()`."""
+        if self._savepoints:
+            raise e.ProgrammingError(
+                "Explicit commit() forbidden within a Transaction "
+                "context. (Transaction will be automatically committed "
+                "on successful exit from context.)"
+            )
+        if self.pgconn.transaction_status == TransactionStatus.IDLE:
+            return
+
+        yield from self._exec_command(b"commit")
+
+    def _rollback_gen(self) -> PQGen[None]:
+        """Generator implementing `Connection.rollback()`."""
+        if self._savepoints:
+            raise e.ProgrammingError(
+                "Explicit rollback() forbidden within a Transaction "
+                "context. (Either raise Rollback() or allow "
+                "an exception to propagate out of the context.)"
+            )
+        if self.pgconn.transaction_status == TransactionStatus.IDLE:
+            return
+
+        yield from self._exec_command(b"rollback")
+
 
 class Connection(BaseConnection):
     """
@@ -247,13 +352,9 @@ class Connection(BaseConnection):
 
         TODO: connection_timeout to be implemented.
         """
-
-        conninfo = make_conninfo(conninfo, **kwargs)
-        gen = connect(conninfo)
-        pgconn = cls.wait(gen)
-        conn = cls(pgconn)
-        conn._autocommit = autocommit
-        return conn
+        return cls._wait_conn(
+            cls._connect_gen(conninfo, autocommit=autocommit, **kwargs)
+        )
 
     def __enter__(self) -> "Connection":
         return self
@@ -284,16 +385,6 @@ class Connection(BaseConnection):
 
         return self.cursor_factory(self, format=format)
 
-    def _start_query(self) -> None:
-        # the function is meant to be called by a cursor once the lock is taken
-        if self._autocommit:
-            return
-
-        if self.pgconn.transaction_status != TransactionStatus.IDLE:
-            return
-
-        self._exec_command(b"begin")
-
     def execute(
         self, query: Query, params: Optional[Params] = None
     ) -> "Cursor":
@@ -304,49 +395,12 @@ class Connection(BaseConnection):
     def commit(self) -> None:
         """Commit any pending transaction to the database."""
         with self.lock:
-            if self._savepoints:
-                raise e.ProgrammingError(
-                    "Explicit commit() forbidden within a Transaction "
-                    "context. (Transaction will be automatically committed "
-                    "on successful exit from context.)"
-                )
-            if self.pgconn.transaction_status == TransactionStatus.IDLE:
-                return
-            self._exec_command(b"commit")
+            self.wait(self._commit_gen())
 
     def rollback(self) -> None:
         """Roll back to the start of any pending transaction."""
         with self.lock:
-            if self._savepoints:
-                raise e.ProgrammingError(
-                    "Explicit rollback() forbidden within a Transaction "
-                    "context. (Either raise Rollback() or allow "
-                    "an exception to propagate out of the context.)"
-                )
-            if self.pgconn.transaction_status == TransactionStatus.IDLE:
-                return
-            self._exec_command(b"rollback")
-
-    def _exec_command(self, command: Query) -> None:
-        # Caller must hold self.lock
-
-        if isinstance(command, str):
-            command = command.encode(self.client_encoding)
-        elif isinstance(command, Composable):
-            command = command.as_string(self).encode(self.client_encoding)
-
-        self.pgconn.send_query(command)
-        result = self.wait(execute(self.pgconn))[-1]
-        if result.status != ExecStatus.COMMAND_OK:
-            if result.status == ExecStatus.FATAL_ERROR:
-                raise e.error_from_result(
-                    result, encoding=self.client_encoding
-                )
-            else:
-                raise e.InterfaceError(
-                    f"unexpected result {ExecStatus(result.status).name}"
-                    f" from command {command.decode('utf8')!r}"
-                )
+            self.wait(self._rollback_gen())
 
     @contextmanager
     def transaction(
@@ -365,27 +419,6 @@ class Connection(BaseConnection):
         with Transaction(self, savepoint_name, force_rollback) as tx:
             yield tx
 
-    @classmethod
-    def wait(
-        cls,
-        gen: Union[PQGen[RV], PQGenConn[RV]],
-        timeout: Optional[float] = 0.1,
-    ) -> RV:
-        return wait(gen, timeout=timeout)
-
-    def _set_client_encoding(self, name: str) -> None:
-        with self.lock:
-            self.pgconn.send_query_params(
-                b"select set_config('client_encoding', $1, false)",
-                [encodings.py2pg(name)],
-            )
-            gen = execute(self.pgconn)
-            (result,) = self.wait(gen)
-            if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(
-                    result, encoding=self.client_encoding
-                )
-
     def notifies(self) -> Iterator[Notify]:
         """
         Yield `Notify` objects as soon as they are received from the database.
@@ -400,10 +433,30 @@ class Connection(BaseConnection):
                 )
                 yield n
 
+    def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+        """
+        Consume a generator operating on the connection.
+
+        The function must be used on generators that don't change the
+        connection fd (i.e. not on connect and reset).
+        """
+        return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
+
+    @classmethod
+    def _wait_conn(
+        cls, gen: PQGenConn[RV], timeout: Optional[float] = 0.1
+    ) -> RV:
+        """Consume a connection generator."""
+        return waiting.wait_conn(gen, timeout=timeout)
+
     def _set_autocommit(self, value: bool) -> None:
         with self.lock:
             super()._set_autocommit(value)
 
+    def _set_client_encoding(self, name: str) -> None:
+        with self.lock:
+            self.wait(self._set_client_encoding_gen(name))
+
 
 class AsyncConnection(BaseConnection):
     """
@@ -423,12 +476,9 @@ class AsyncConnection(BaseConnection):
     async def connect(
         cls, conninfo: str = "", *, autocommit: bool = False, **kwargs: Any
     ) -> "AsyncConnection":
-        conninfo = make_conninfo(conninfo, **kwargs)
-        gen = connect(conninfo)
-        pgconn = await cls.wait(gen)
-        conn = cls(pgconn)
-        conn._autocommit = autocommit
-        return conn
+        return await cls._wait_conn(
+            cls._connect_gen(conninfo, autocommit=autocommit, **kwargs)
+        )
 
     async def __aenter__(self) -> "AsyncConnection":
         return self
@@ -460,16 +510,6 @@ class AsyncConnection(BaseConnection):
 
         return self.cursor_factory(self, format=format)
 
-    async def _start_query(self) -> None:
-        # the function is meant to be called by a cursor once the lock is taken
-        if self._autocommit:
-            return
-
-        if self.pgconn.transaction_status != TransactionStatus.IDLE:
-            return
-
-        await self._exec_command(b"begin")
-
     async def execute(
         self, query: Query, params: Optional[Params] = None
     ) -> "AsyncCursor":
@@ -478,48 +518,11 @@ class AsyncConnection(BaseConnection):
 
     async def commit(self) -> None:
         async with self.lock:
-            if self._savepoints:
-                raise e.ProgrammingError(
-                    "Explicit commit() forbidden within a Transaction "
-                    "context. (Transaction will be automatically committed "
-                    "on successful exit from context.)"
-                )
-            if self.pgconn.transaction_status == TransactionStatus.IDLE:
-                return
-            await self._exec_command(b"commit")
+            await self.wait(self._commit_gen())
 
     async def rollback(self) -> None:
         async with self.lock:
-            if self._savepoints:
-                raise e.ProgrammingError(
-                    "Explicit rollback() forbidden within a Transaction "
-                    "context. (Either raise Rollback() or allow "
-                    "an exception to propagate out of the context.)"
-                )
-            if self.pgconn.transaction_status == TransactionStatus.IDLE:
-                return
-            await self._exec_command(b"rollback")
-
-    async def _exec_command(self, command: Query) -> None:
-        # Caller must hold self.lock
-
-        if isinstance(command, str):
-            command = command.encode(self.client_encoding)
-        elif isinstance(command, Composable):
-            command = command.as_string(self).encode(self.client_encoding)
-
-        self.pgconn.send_query(command)
-        result = (await self.wait(execute(self.pgconn)))[-1]
-        if result.status != ExecStatus.COMMAND_OK:
-            if result.status == ExecStatus.FATAL_ERROR:
-                raise e.error_from_result(
-                    result, encoding=self.client_encoding
-                )
-            else:
-                raise e.InterfaceError(
-                    f"unexpected result {ExecStatus(result.status).name}"
-                    f" from command {command.decode('utf8')!r}"
-                )
+            await self.wait(self._rollback_gen())
 
     @asynccontextmanager
     async def transaction(
@@ -534,9 +537,23 @@ class AsyncConnection(BaseConnection):
         async with tx:
             yield tx
 
+    async def notifies(self) -> AsyncIterator[Notify]:
+        while 1:
+            async with self.lock:
+                ns = await self.wait(notifies(self.pgconn))
+            enc = self.client_encoding
+            for pgn in ns:
+                n = Notify(
+                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
+                )
+                yield n
+
+    async def wait(self, gen: PQGen[RV]) -> RV:
+        return await waiting.wait_async(gen, self.pgconn.socket)
+
     @classmethod
-    async def wait(cls, gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
-        return await wait_async(gen)
+    async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV:
+        return await waiting.wait_async_conn(gen)
 
     def _set_client_encoding(self, name: str) -> None:
         raise AttributeError(
@@ -547,27 +564,7 @@ class AsyncConnection(BaseConnection):
     async def set_client_encoding(self, name: str) -> None:
         """Async version of the `~Connection.client_encoding` setter."""
         async with self.lock:
-            self.pgconn.send_query_params(
-                b"select set_config('client_encoding', $1, false)",
-                [encodings.py2pg(name)],
-            )
-            gen = execute(self.pgconn)
-            (result,) = await self.wait(gen)
-            if result.status != ExecStatus.TUPLES_OK:
-                raise e.error_from_result(
-                    result, encoding=self.client_encoding
-                )
-
-    async def notifies(self) -> AsyncIterator[Notify]:
-        while 1:
-            async with self.lock:
-                ns = await self.wait(notifies(self.pgconn))
-            enc = self.client_encoding
-            for pgn in ns:
-                n = Notify(
-                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
-                )
-                yield n
+            await self.wait(self._set_client_encoding_gen(name))
 
     def _set_autocommit(self, value: bool) -> None:
         raise AttributeError(
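
Both the blocking and the asyncio front-ends now funnel through the
same base generators; a hedged usage sketch of the resulting public API
(the DSN is an assumption, and closing the async connection is omitted
for brevity):

    import asyncio
    import psycopg3

    # Sync side: Connection.wait() drives the shared generators with a
    # selector-based loop.
    conn = psycopg3.Connection.connect("dbname=test")
    conn.execute("select 1")
    conn.commit()
    conn.close()

    # Async side: the same generators, awaited through the event loop.
    async def main() -> None:
        aconn = await psycopg3.AsyncConnection.connect("dbname=test")
        await aconn.execute("select 1")
        await aconn.commit()

    asyncio.run(main())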
index 1f4b854b29d189086f8e9a2b6a7a63018f4d0c22..af161b4e7b77b37ca68c0bdd44b3853a25028e52 100644 (file)
@@ -12,7 +12,7 @@ from contextlib import contextmanager
 
 from . import errors as e
 from . import pq
-from .pq import ConnStatus, ExecStatus, Format
+from .pq import ExecStatus, Format
 from .copy import Copy, AsyncCopy
 from .proto import ConnectionType, Query, Params, DumpersMap, LoadersMap, PQGen
 from ._column import Column
@@ -148,29 +148,76 @@ class BaseCursor(Generic[ConnectionType]):
         else:
             return None
 
-    def _start_query(self) -> None:
+    #
+    # Generators for the high level operations on the cursor
+    #
+    # As for the sync/async connections, these are implemented as generators
+    # so that different concurrency strategies (threads, asyncio) can use
+    # their own way of waiting (or better, `connection.wait()`).
+    #
+
+    def _execute_gen(
+        self, query: Query, params: Optional[Params] = None
+    ) -> PQGen[None]:
+        """Generator implementing `Cursor.execute()`."""
+        yield from self._start_query()
+        self._execute_send(query, params)
+        results = yield from execute(self._conn.pgconn)
+        self._execute_results(results)
+
+    def _executemany_gen(
+        self, query: Query, params_seq: Sequence[Params]
+    ) -> PQGen[None]:
+        """Generator implementing `Cursor.executemany()`."""
+        yield from self._start_query()
+        first = True
+        for params in params_seq:
+            if first:
+                pgq = self._send_prepare(b"", query, params)
+                (result,) = yield from execute(self._conn.pgconn)
+                if result.status == ExecStatus.FATAL_ERROR:
+                    raise e.error_from_result(
+                        result, encoding=self._conn.client_encoding
+                    )
+            else:
+                pgq.dump(params)
+
+            self._send_query_prepared(b"", pgq)
+            (result,) = yield from execute(self._conn.pgconn)
+            self._execute_results((result,))
+
+    def _start_query(self) -> PQGen[None]:
+        """Generator to start the processing of a query.
+
+        It is implemented as a generator because it may send additional
+        queries, such as `begin`.
+        """
         from . import adapt
 
         if self.closed:
             raise e.InterfaceError("the cursor is closed")
 
-        if self._conn.closed:
-            raise e.InterfaceError("the connection is closed")
-
-        if self._conn.pgconn.status != ConnStatus.OK:
-            raise e.InterfaceError(
-                f"cannot execute operations: the connection is"
-                f" in status {self._conn.pgconn.status}"
-            )
-
         self._reset()
         self._transformer = adapt.Transformer(self)
+        yield from self._conn._start_query()
+
+    def _start_copy_gen(self, statement: Query) -> PQGen[None]:
+        """Generator implementing sending a command for `Cursor.copy()."""
+        yield from self._start_query()
+        # Make sure to avoid PQexec, so that we don't receive a mix of COPY
+        # and other operations.
+        self._execute_send(statement, None, no_pqexec=True)
+        (result,) = yield from execute(self._conn.pgconn)
+        self._check_copy_result(result)
+        self.pgresult = result  # will set it on the transformer too
 
     def _execute_send(
         self, query: Query, params: Optional[Params], no_pqexec: bool = False
     ) -> None:
         """
-        Implement part of execute() before waiting common to sync and async
+        Implement part of execute() before waiting common to sync and async.
+
+        This is not a generator, but a normal, non-blocking function.
         """
         pgq = PostgresQuery(self._transformer)
         pgq.convert(query, params)
@@ -206,6 +253,8 @@ class BaseCursor(Generic[ConnectionType]):
     def _execute_results(self, results: Sequence["PGresult"]) -> None:
         """
         Implement part of execute() after waiting common to sync and async
+
+        This is not a generator, but a normal, non-blocking function.
         """
         if not results:
             raise e.InternalError("got no result from the query")
@@ -241,9 +290,6 @@ class BaseCursor(Generic[ConnectionType]):
     def _send_prepare(
         self, name: bytes, query: Query, params: Optional[Params]
     ) -> PostgresQuery:
-        """
-        Implement part of execute() before waiting common to sync and async
-        """
         pgq = PostgresQuery(self._transformer)
         pgq.convert(query, params)
 
@@ -270,16 +316,10 @@ class BaseCursor(Generic[ConnectionType]):
                 "the last operation didn't produce a result"
             )
 
-    def _check_copy_results(self, results: Sequence["PGresult"]) -> None:
+    def _check_copy_result(self, result: "PGresult") -> None:
         """
         Check that the value returned in a copy() operation is a legit COPY.
         """
-        if len(results) != 1:
-            raise e.InternalError(
-                f"expected 1 result from copy, got {len(results)} instead"
-            )
-
-        result = results[0]
         status = result.status
         if status in (ExecStatus.COPY_IN, ExecStatus.COPY_OUT):
             return
@@ -322,12 +362,7 @@ class Cursor(BaseCursor["Connection"]):
         Execute a query or command to the database.
         """
         with self._conn.lock:
-            self._start_query()
-            self._conn._start_query()
-            self._execute_send(query, params)
-            gen = execute(self._conn.pgconn)
-            results = self._conn.wait(gen)
-            self._execute_results(results)
+            self._conn.wait(self._execute_gen(query, params))
         return self
 
     def executemany(self, query: Query, params_seq: Sequence[Params]) -> None:
@@ -335,25 +370,7 @@ class Cursor(BaseCursor["Connection"]):
         Execute the same command with a sequence of input data.
         """
         with self._conn.lock:
-            self._start_query()
-            self._conn._start_query()
-            first = True
-            for params in params_seq:
-                if first:
-                    pgq = self._send_prepare(b"", query, params)
-                    gen = execute(self._conn.pgconn)
-                    (result,) = self._conn.wait(gen)
-                    if result.status == ExecStatus.FATAL_ERROR:
-                        raise e.error_from_result(
-                            result, encoding=self._conn.client_encoding
-                        )
-                else:
-                    pgq.dump(params)
-
-                self._send_query_prepared(b"", pgq)
-                gen = execute(self._conn.pgconn)
-                (result,) = self._conn.wait(gen)
-                self._execute_results((result,))
+            self._conn.wait(self._executemany_gen(query, params_seq))
 
     def fetchone(self) -> Optional[Sequence[Any]]:
         """
@@ -414,22 +431,11 @@ class Cursor(BaseCursor["Connection"]):
         """
         Initiate a :sql:`COPY` operation and return an object to manage it.
         """
-        with self._start_copy(statement) as copy:
-            yield copy
-
-    def _start_copy(self, statement: Query) -> Copy:
         with self._conn.lock:
-            self._start_query()
-            self._conn._start_query()
-            # Make sure to avoid PQexec to avoid receiving a mix of COPY and
-            # other operations.
-            self._execute_send(statement, None, no_pqexec=True)
-            gen = execute(self._conn.pgconn)
-            results = self._conn.wait(gen)
-            self._check_copy_results(results)
-            self.pgresult = results[0]  # will set it on the transformer too
+            self._conn.wait(self._start_copy_gen(statement))
 
-        return Copy(self)
+        with Copy(self) as copy:
+            yield copy
 
 
 class AsyncCursor(BaseCursor["AsyncConnection"]):
@@ -454,37 +460,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         self, query: Query, params: Optional[Params] = None
     ) -> "AsyncCursor":
         async with self._conn.lock:
-            self._start_query()
-            await self._conn._start_query()
-            self._execute_send(query, params)
-            gen = execute(self._conn.pgconn)
-            results = await self._conn.wait(gen)
-            self._execute_results(results)
+            await self._conn.wait(self._execute_gen(query, params))
         return self
 
     async def executemany(
         self, query: Query, params_seq: Sequence[Params]
     ) -> None:
         async with self._conn.lock:
-            self._start_query()
-            await self._conn._start_query()
-            first = True
-            for params in params_seq:
-                if first:
-                    pgq = self._send_prepare(b"", query, params)
-                    gen = execute(self._conn.pgconn)
-                    (result,) = await self._conn.wait(gen)
-                    if result.status == ExecStatus.FATAL_ERROR:
-                        raise e.error_from_result(
-                            result, encoding=self._conn.client_encoding
-                        )
-                else:
-                    pgq.dump(params)
-
-                self._send_query_prepared(b"", pgq)
-                gen = execute(self._conn.pgconn)
-                (result,) = await self._conn.wait(gen)
-                self._execute_results((result,))
+            await self._conn.wait(self._executemany_gen(query, params_seq))
 
     async def fetchone(self) -> Optional[Sequence[Any]]:
         self._check_result()
@@ -533,23 +516,11 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
 
     @asynccontextmanager
     async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
-        copy = await self._start_copy(statement)
-        async with copy:
-            yield copy
-
-    async def _start_copy(self, statement: Query) -> AsyncCopy:
         async with self._conn.lock:
-            self._start_query()
-            await self._conn._start_query()
-            # Make sure to avoid PQexec to avoid receiving a mix of COPY and
-            # other operations.
-            self._execute_send(statement, None, no_pqexec=True)
-            gen = execute(self._conn.pgconn)
-            results = await self._conn.wait(gen)
-            self._check_copy_results(results)
-            self.pgresult = results[0]  # will set it on the transformer too
-
-        return AsyncCopy(self)
+            await self._conn.wait(self._start_copy_gen(statement))
+
+        async with AsyncCopy(self) as copy:
+            yield copy
 
 
 class NamedCursorMixin:
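
The public cursor API is unchanged; only the plumbing now goes through
the generators above, consumed under the connection lock. A hedged
usage sketch (the open connection and the temporary table are
assumptions):

    cur = conn.cursor()
    cur.execute("create temp table nums (n int)")
    # executemany() prepares the statement once, then sends it prepared
    # for each parameter set, all within a single generator.
    cur.executemany("insert into nums values (%s)", [(1,), (2,), (3,)])
    cur.execute("select sum(n) from nums")
    assert cur.fetchone()[0] == 6
    conn.commit()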
index 50a6c4e5bdc253158c66e96b6d39c2c4cef68794..cbf5d6dd65e706c3491320e33699c7ab5b1d303f 100644 (file)
@@ -88,7 +88,6 @@ def send(pgconn: PGconn) -> PQGen[None]:
     After this generator has finished you may want to cycle using `_fetch()`
     to retrieve the results available.
     """
-    yield pgconn.socket
     while 1:
         f = pgconn.flush()
         if f == 0:
@@ -150,7 +149,6 @@ _copy_statuses = (
 
 
 def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
-    yield pgconn.socket
     yield Wait.R
     pgconn.consume_input()
 
@@ -166,7 +164,6 @@ def notifies(pgconn: PGconn) -> PQGen[List[pq.PGnotify]]:
 
 
 def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
-    yield pgconn.socket
     while 1:
         nbytes, data = pgconn.get_copy_data(1)
         if nbytes != 0:
@@ -192,14 +189,12 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[bytes, PGresult]]:
 
 
 def copy_to(pgconn: PGconn, buffer: bytes) -> PQGen[None]:
-    yield pgconn.socket
     # Retry enqueuing data until successful
     while pgconn.put_copy_data(buffer) == 0:
         yield Wait.W
 
 
 def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
-    yield pgconn.socket
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
         yield Wait.W
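
With the fd gone from the yielded stream, two operations now chain
naturally inside one generator with a plain "yield from". A purely
illustrative helper (not library code), consumed like any other PQGen:

    from psycopg3.proto import PQGen
    from psycopg3.generators import execute

    def two_queries(pgconn) -> PQGen[list]:
        # Each "yield from" runs a full send/fetch cycle on the same fd.
        pgconn.send_query(b"select 1")
        first = yield from execute(pgconn)
        pgconn.send_query(b"select 2")
        second = yield from execute(pgconn)
        return first + second

    # e.g. results = waiting.wait(two_queries(pgconn), pgconn.socket)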
index 3deb72dc27f331acee47d345fe84bfa072aef817..fff6ffee0aed5ae113506f9d369c1223d11f2dec 100644 (file)
@@ -33,7 +33,7 @@ PQGenConn = Generator[Tuple[int, "Wait"], "Ready", RV]
 This can happen in connection and reset, but not in normal querying.
 """
 
-PQGen = Generator[Union[int, "Wait"], "Ready", RV]
+PQGen = Generator["Wait", "Ready", RV]
 """Generator for processes where the connection file number won't change.
 
 The first item generated is the file descriptor; following items are the
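
Illustrative toys matching the two generator shapes (not library code):
PQGen yields bare `Wait` values because the fd is fixed and supplied by
the caller, while PQGenConn yields (fd, `Wait`) pairs because the fd
may change while connecting.

    from psycopg3.proto import PQGen, PQGenConn
    from psycopg3.waiting import Wait

    def query_like(pgconn) -> PQGen[None]:
        yield Wait.W           # wait until the socket is writable...
        yield Wait.R           # ...then until it is readable

    def connect_like(fileno: int) -> PQGenConn[None]:
        yield fileno, Wait.W   # each step reports which fd to wait on
        yield fileno, Wait.RW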
index 21c28fda01bb325fe6f6adbe1f2cfbb49524f1c4..6be6e34f48fcdcb110019b4bd28127b64fb74d41 100644 (file)
@@ -7,11 +7,11 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, Optional, Type, Union, TYPE_CHECKING
 
 from . import sql
 from .pq import TransactionStatus
-from .proto import ConnectionType
+from .proto import ConnectionType, PQGen
 
 if TYPE_CHECKING:
     from .connection import Connection, AsyncConnection  # noqa: F401
@@ -74,7 +74,7 @@ class BaseTransaction(Generic[ConnectionType]):
             args.append("force_rollback=True")
         return f"{self.__class__.__qualname__}({', '.join(args)})"
 
-    def _enter_commands(self) -> List[bytes]:
+    def _enter_gen(self) -> PQGen[None]:
         if not self._yolo:
             raise TypeError("transaction blocks can be used only once")
         else:
@@ -107,9 +107,21 @@ class BaseTransaction(Generic[ConnectionType]):
             )
 
         self._conn._savepoints.append(self._savepoint_name)
-        return commands
+        return self._conn._exec_command(b"; ".join(commands))
 
-    def _commit_commands(self) -> List[bytes]:
+    def _exit_gen(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> PQGen[bool]:
+        if not exc_val and not self.force_rollback:
+            yield from self._commit_gen()
+            return False
+        else:
+            return (yield from self._rollback_gen(exc_val))
+
+    def _commit_gen(self) -> PQGen[None]:
         assert self._conn._savepoints[-1] == self._savepoint_name
         self._conn._savepoints.pop()
 
@@ -125,9 +137,14 @@ class BaseTransaction(Generic[ConnectionType]):
             assert not self._conn._savepoints
             commands.append(b"commit")
 
-        return commands
+        return self._conn._exec_command(b"; ".join(commands))
+
+    def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
+        if isinstance(exc_val, Rollback):
+            _log.debug(
+                f"{self._conn}: Explicit rollback from: ", exc_info=True
+            )
 
-    def _rollback_commands(self) -> List[bytes]:
         assert self._conn._savepoints[-1] == self._savepoint_name
         self._conn._savepoints.pop()
 
@@ -143,7 +160,13 @@ class BaseTransaction(Generic[ConnectionType]):
             assert not self._conn._savepoints
             commands.append(b"rollback")
 
-        return commands
+        yield from self._conn._exec_command(b"; ".join(commands))
+
+        if isinstance(exc_val, Rollback):
+            if not exc_val.transaction or exc_val.transaction is self:
+                return True  # Swallow the exception
+
+        return False
 
 
 class Transaction(BaseTransaction["Connection"]):
@@ -155,7 +178,7 @@ class Transaction(BaseTransaction["Connection"]):
 
     def __enter__(self) -> "Transaction":
         with self._conn.lock:
-            self._execute(self._enter_commands())
+            self._conn.wait(self._enter_gen())
         return self
 
     def __exit__(
@@ -165,33 +188,7 @@ class Transaction(BaseTransaction["Connection"]):
         exc_tb: Optional[TracebackType],
     ) -> bool:
         with self._conn.lock:
-            if not exc_val and not self.force_rollback:
-                self._commit()
-                return False
-            else:
-                return self._rollback(exc_val)
-
-    def _commit(self) -> None:
-        """Commit changes made in the transaction context."""
-        self._execute(self._commit_commands())
-
-    def _rollback(self, exc_val: Optional[BaseException]) -> bool:
-        # Rollback changes made in the transaction context
-        if isinstance(exc_val, Rollback):
-            _log.debug(
-                f"{self._conn}: Explicit rollback from: ", exc_info=True
-            )
-
-        self._execute(self._rollback_commands())
-
-        if isinstance(exc_val, Rollback):
-            if exc_val.transaction in (self, None):
-                return True  # Swallow the exception
-
-        return False
-
-    def _execute(self, commands: List[bytes]) -> None:
-        self._conn._exec_command(b"; ".join(commands))
+            return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
 
 
 class AsyncTransaction(BaseTransaction["AsyncConnection"]):
@@ -203,8 +200,7 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
 
     async def __aenter__(self) -> "AsyncTransaction":
         async with self._conn.lock:
-            await self._execute(self._enter_commands())
-
+            await self._conn.wait(self._enter_gen())
         return self
 
     async def __aexit__(
@@ -214,30 +210,6 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
         exc_tb: Optional[TracebackType],
     ) -> bool:
         async with self._conn.lock:
-            if not exc_val and not self.force_rollback:
-                await self._commit()
-                return False
-            else:
-                return await self._rollback(exc_val)
-
-    async def _commit(self) -> None:
-        """Commit changes made in the transaction context."""
-        await self._execute(self._commit_commands())
-
-    async def _rollback(self, exc_val: Optional[BaseException]) -> bool:
-        # Rollback changes made in the transaction context
-        if isinstance(exc_val, Rollback):
-            _log.debug(
-                f"{self._conn}: Explicit rollback from: ", exc_info=True
+            return await self._conn.wait(
+                self._exit_gen(exc_type, exc_val, exc_tb)
             )
-
-        await self._execute(self._rollback_commands())
-
-        if isinstance(exc_val, Rollback):
-            if exc_val.transaction in (self, None):
-                return True  # Swallow the exception
-
-        return False
-
-    async def _execute(self, commands: List[bytes]) -> None:
-        await self._conn._exec_command(b"; ".join(commands))
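
A hedged usage sketch of the behaviour implemented by _exit_gen():
raising Rollback inside a nested block rolls back only that savepoint
and is swallowed. The open connection and the table name are
assumptions.

    from psycopg3 import Rollback

    with conn.transaction():
        conn.execute("insert into test_data values (1)")
        with conn.transaction() as inner:
            conn.execute("insert into test_data values (2)")
            # Swallowed by the inner block's __exit__: only the inner
            # savepoint is rolled back.
            raise Rollback(inner)
    # Only the first row is committed when the outer block exits.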
index 85e1119662079c06cc37949c99c3dc401c69d8a6..af818fbb862a55c2c364e6923b317a2be3a54510 100644 (file)
@@ -10,7 +10,7 @@ These functions are designed to consume the generators returned by the
 
 
 from enum import IntEnum
-from typing import Optional, Union
+from typing import Optional
 from asyncio import get_event_loop, Event
 from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
 
@@ -29,11 +29,40 @@ class Ready(IntEnum):
     W = EVENT_WRITE
 
 
-def wait(
-    gen: Union[PQGen[RV], PQGenConn[RV]], timeout: Optional[float] = None
-) -> RV:
+def wait(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
     """
-    Wait for a generator using the best option available on the platform.
+    Wait for a generator using the best strategy available.
+
+    :param gen: a generator performing database operations and yielding
+        `Ready` values when it would block.
+    :param fileno: the file descriptor to wait on.
+    :param timeout: timeout (in seconds) to check for other interrupt, e.g.
+        to allow Ctrl-C.
+    :type timeout: float
+    :return: whatever *gen* returns on completion.
+
+    Consume *gen*, waiting on `fileno` whenever the generator reports it
+    would block. Once the socket is ready, send the ready state back to
+    *gen*.
+    """
+    sel = DefaultSelector()
+    try:
+        s = next(gen)
+        while 1:
+            sel.register(fileno, s)
+            ready = None
+            while not ready:
+                ready = sel.select(timeout=timeout)
+            sel.unregister(fileno)
+            s = gen.send(ready[0][1])
+
+    except StopIteration as ex:
+        rv: RV = ex.args[0] if ex.args else None
+        return rv
+
+
+def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
+    """
+    Wait for a connection generator using the best strategy available.
 
     :param gen: a generator performing database operations and yielding
         (fd, `Ready`) pairs when it would block.
@@ -41,59 +70,94 @@ def wait(
         to allow Ctrl-C.
     :type timeout: float
     :return: whatever *gen* returns on completion.
+
+    Behave like `wait()`, but take the fileno to wait on from the generator
+    itself, as it might change during processing.
     """
-    fd: int
-    s: Wait
     sel = DefaultSelector()
     try:
-        # Use the first generated item to tell if it's a PQgen or PQgenConn.
-        # Note: mypy gets confused by the behaviour of this generator.
-        item = next(gen)
-        if isinstance(item, tuple):
-            fd, s = item
-            while 1:
-                sel.register(fd, s)
-                ready = None
-                while not ready:
-                    ready = sel.select(timeout=timeout)
-                sel.unregister(fd)
-
-                assert len(ready) == 1
-                fd, s = gen.send(ready[0][1])
-        else:
-            fd = item  # type: ignore[assignment]
-            s = next(gen)  # type: ignore[assignment]
-            while 1:
-                sel.register(fd, s)
-                ready = None
-                while not ready:
-                    ready = sel.select(timeout=timeout)
-                sel.unregister(fd)
-
-                assert len(ready) == 1
-                s = gen.send(ready[0][1])  # type: ignore[arg-type,assignment]
+        fileno, s = next(gen)
+        while 1:
+            sel.register(fileno, s)
+            ready = None
+            while not ready:
+                ready = sel.select(timeout=timeout)
+            sel.unregister(fileno)
+            fileno, s = gen.send(ready[0][1])
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
         return rv
 
 
-async def wait_async(gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
+async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
     """
     Coroutine waiting for a generator to complete.
 
-    *gen* is expected to generate tuples (fd, status). consume it and block
-    according to the status until fd is ready. Send back the ready state
-    to the generator.
+    :param gen: a generator performing database operations and yielding
+        `Ready` values when it would block.
+    :param fileno: the file descriptor to wait on.
+    :return: whatever *gen* returns on completion.
+
+    Behave like `wait()`, but expose an `asyncio` interface.
+    """
+    # Use an event to block and restart after the fd state changes.
+    # Not sure this is the best implementation but it's a start.
+    ev = Event()
+    loop = get_event_loop()
+    ready: Ready
+    s: Wait
+
+    def wakeup(state: Ready) -> None:
+        nonlocal ready
+        ready = state
+        ev.set()
+
+    try:
+        s = next(gen)
+        while 1:
+            ev.clear()
+            if s == Wait.R:
+                loop.add_reader(fileno, wakeup, Ready.R)
+                await ev.wait()
+                loop.remove_reader(fileno)
+            elif s == Wait.W:
+                loop.add_writer(fileno, wakeup, Ready.W)
+                await ev.wait()
+                loop.remove_writer(fileno)
+            elif s == Wait.RW:
+                loop.add_reader(fileno, wakeup, Ready.R)
+                loop.add_writer(fileno, wakeup, Ready.W)
+                await ev.wait()
+                loop.remove_reader(fileno)
+                loop.remove_writer(fileno)
+            else:
+                raise e.InternalError(f"bad poll status: {s}")
+            s = gen.send(ready)
+
+    except StopIteration as ex:
+        rv: RV = ex.args[0] if ex.args else None
+        return rv
+
+
+async def wait_async_conn(gen: PQGenConn[RV]) -> RV:
+    """
+    Coroutine waiting for a connection generator to complete.
+
+    :param gen: a generator performing database operations and yielding
+        (fd, `Ready`) pairs when it would block.
+    :return: whatever *gen* returns on completion.
 
-    Return what the generator eventually returned.
+    Behave like `wait()`, but take the fileno to wait on from the generator
+    itself, as it might change during processing.
     """
     # Use an event to block and restart after the fd state changes.
     # Not sure this is the best implementation but it's a start.
     ev = Event()
     loop = get_event_loop()
     ready: Ready
-    fd: int
     s: Wait
 
     def wakeup(state: Ready) -> None:
@@ -102,52 +166,26 @@ async def wait_async(gen: Union[PQGen[RV], PQGenConn[RV]]) -> RV:
         ev.set()
 
     try:
-        # Use the first generated item to tell if it's a PQgen or PQgenConn.
-        # Note: mypy gets confused by the behaviour of this generator.
-        item = next(gen)
-        if isinstance(item, tuple):
-            fd, s = item
-            while 1:
-                ev.clear()
-                if s == Wait.R:
-                    loop.add_reader(fd, wakeup, Ready.R)
-                    await ev.wait()
-                    loop.remove_reader(fd)
-                elif s == Wait.W:
-                    loop.add_writer(fd, wakeup, Ready.W)
-                    await ev.wait()
-                    loop.remove_writer(fd)
-                elif s == Wait.RW:
-                    loop.add_reader(fd, wakeup, Ready.R)
-                    loop.add_writer(fd, wakeup, Ready.W)
-                    await ev.wait()
-                    loop.remove_reader(fd)
-                    loop.remove_writer(fd)
-                else:
-                    raise e.InternalError("bad poll status: %s")
-                fd, s = gen.send(ready)  # type: ignore[misc]
-        else:
-            fd = item  # type: ignore[assignment]
-            s = next(gen)  # type: ignore[assignment]
-            while 1:
-                ev.clear()
-                if s == Wait.R:
-                    loop.add_reader(fd, wakeup, Ready.R)
-                    await ev.wait()
-                    loop.remove_reader(fd)
-                elif s == Wait.W:
-                    loop.add_writer(fd, wakeup, Ready.W)
-                    await ev.wait()
-                    loop.remove_writer(fd)
-                elif s == Wait.RW:
-                    loop.add_reader(fd, wakeup, Ready.R)
-                    loop.add_writer(fd, wakeup, Ready.W)
-                    await ev.wait()
-                    loop.remove_reader(fd)
-                    loop.remove_writer(fd)
-                else:
-                    raise e.InternalError("bad poll status: %s")
-                s = gen.send(ready)  # type: ignore[arg-type,assignment]
+        fileno, s = next(gen)
+        while 1:
+            ev.clear()
+            if s == Wait.R:
+                loop.add_reader(fileno, wakeup, Ready.R)
+                await ev.wait()
+                loop.remove_reader(fileno)
+            elif s == Wait.W:
+                loop.add_writer(fileno, wakeup, Ready.W)
+                await ev.wait()
+                loop.remove_writer(fileno)
+            elif s == Wait.RW:
+                loop.add_reader(fileno, wakeup, Ready.R)
+                loop.add_writer(fileno, wakeup, Ready.W)
+                await ev.wait()
+                loop.remove_reader(fileno)
+                loop.remove_writer(fileno)
+            else:
+                raise e.InternalError(f"bad poll status: {s}")
+            fileno, s = gen.send(ready)
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
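
A self-contained toy (no database needed) showing the protocol that
wait() implements: the generator yields `Wait` values, the consumer
polls the given fd and sends the ready state back. A socketpair stands
in for the libpq socket.

    import socket
    from psycopg3.proto import PQGen
    from psycopg3.waiting import Wait, wait

    def echo(sock: socket.socket) -> PQGen[bytes]:
        yield Wait.W                 # block until writable, then send
        sock.send(b"ping")
        yield Wait.R                 # block until readable, then receive
        return sock.recv(4)

    a, b = socket.socketpair()
    b.send(b"pong")                  # make `a` readable for the second step
    assert wait(echo(a), a.fileno()) == b"pong"
    assert b.recv(4) == b"ping"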
index 16eb8267bebdb35e7560b53045957beab577c6cd..baf0119e7c3c09d42a49fd9f7e0847ea9cbcf2ad 100644 (file)
@@ -73,10 +73,6 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
     cdef libpq.PGresult *pgres
     cdef int cires, ibres
 
-    # Start the generator by sending the connection fd, which won't change
-    # during the query process.
-    yield libpq.PQsocket(pgconn_ptr)
-
     # Sending the query
     while 1:
         if libpq.PQflush(pgconn_ptr) == 0:
index 9fddde267881a4a5fe20ba63e707dda427d0ec75..820f6e85d799d116c13677cfe6bcd3b4311c3a9a 100644 (file)
@@ -5,6 +5,10 @@ from psycopg3 import pq
 from psycopg3.generators import execute
 
 
+def execute_wait(pgconn):
+    return psycopg3.waiting.wait(execute(pgconn), pgconn.socket)
+
+
 def test_send_query(pgconn):
     # This test shows how to process an async query in all its glory
     pgconn.nonblocking = 1
@@ -63,7 +67,7 @@ def test_send_query_compact_test(pgconn):
         b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
         % (b"x" * 1_000_000)
     )
-    results = psycopg3.waiting.wait(execute(pgconn))
+    results = execute_wait(pgconn)
 
     assert len(results) == 2
     assert results[0].nfields == 1
@@ -80,7 +84,7 @@ def test_send_query_compact_test(pgconn):
 
 def test_send_query_params(pgconn):
     pgconn.send_query_params(b"select $1::int + $2", [b"5", b"3"])
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"8"
 
@@ -91,11 +95,11 @@ def test_send_query_params(pgconn):
 
 def test_send_prepare(pgconn):
     pgconn.send_prepare(b"prep", b"select $1::int + $2::int")
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(b"prep", [b"3", b"5"])
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.get_value(0, 0) == b"8"
 
     pgconn.finish()
@@ -107,22 +111,22 @@ def test_send_prepare(pgconn):
 
 def test_send_prepare_types(pgconn):
     pgconn.send_prepare(b"prep", b"select $1 + $2", [23, 23])
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(b"prep", [b"3", b"5"])
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.get_value(0, 0) == b"8"
 
 
 def test_send_prepared_binary_in(pgconn):
     val = b"foo\00bar"
     pgconn.send_prepare(b"", b"select length($1::bytea), length($2::bytea)")
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(b"", [val, val], param_formats=[0, 1])
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"3"
     assert res.get_value(0, 1) == b"7"
@@ -137,12 +141,12 @@ def test_send_prepared_binary_in(pgconn):
 def test_send_prepared_binary_out(pgconn, fmt, out):
     val = b"foo\00bar"
     pgconn.send_prepare(b"", b"select $1::bytea")
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
     pgconn.send_query_prepared(
         b"", [val], param_formats=[1], result_format=fmt
     )
-    (res,) = psycopg3.waiting.wait(execute(pgconn))
+    (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == out
index 17c28877c9e9ba1e6b4bbcd88a9ef83097950438..01e2e2362ae0a579df1174a06d84a2b3f588e5bd 100644 (file)
@@ -242,7 +242,7 @@ def test_transaction_status(pgconn):
     assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
     pgconn.send_query(b"select 1")
     assert pgconn.transaction_status == pq.TransactionStatus.ACTIVE
-    psycopg3.waiting.wait(psycopg3.generators.execute(pgconn))
+    psycopg3.waiting.wait(psycopg3.generators.execute(pgconn), pgconn.socket)
     assert pgconn.transaction_status == pq.TransactionStatus.INTRANS
     pgconn.finish()
     assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN
index af938857e0860b87e347a19a5f144dd5cb7f39fb..0be39e0e9c57b25235ffaacdb1573cd0dd53ca37 100644 (file)
@@ -68,7 +68,7 @@ def test_close(conn):
     assert conn.closed
     assert conn.pgconn.status == conn.ConnStatus.BAD
 
-    with pytest.raises(psycopg3.InterfaceError):
+    with pytest.raises(psycopg3.OperationalError):
         cur.execute("select 1")
 
 
index de1879e9c860a4fe76780f5d4d2ef799fb09079c..2d57d9c0f29ea20e686a34aafa8d662d2a639f7d 100644 (file)
@@ -73,7 +73,7 @@ async def test_close(aconn):
     assert aconn.closed
     assert aconn.pgconn.status == aconn.ConnStatus.BAD
 
-    with pytest.raises(psycopg3.InterfaceError):
+    with pytest.raises(psycopg3.OperationalError):
         await cur.execute("select 1")
 
 
index 1caa2c289c4d6ad2ef06a1d943af7a1f53f073e6..5458216cd57caf59046666784667a4feeb0b7723 100644 (file)
@@ -54,7 +54,12 @@ class ListPopAll(list):
 
 @pytest.fixture
 def commands(conn, monkeypatch):
-    """The queue of commands issued internally by the test connection."""
+    """The list of commands issued internally by the test connection."""
+    yield patch_exec(conn, monkeypatch)
+
+
+def patch_exec(conn, monkeypatch):
+    """Helper to implement the commands fixture both sync and async."""
     _orig_exec_command = conn._exec_command
     L = ListPopAll()
 
@@ -63,10 +68,10 @@ def commands(conn, monkeypatch):
             command = command.decode(conn.client_encoding)
 
         L.insert(0, command)
-        _orig_exec_command(command)
+        return _orig_exec_command(command)
 
     monkeypatch.setattr(conn, "_exec_command", _exec_command)
-    yield L
+    return L
 
 
 def in_transaction(conn):
index f9cbd420dff66d73b6bb70fef6ac8d830f37545c..f3d1582e92628fc26499eb44f6a18d6fe0b75c79 100644 (file)
@@ -3,27 +3,16 @@ import pytest
 from psycopg3 import ProgrammingError, Rollback
 
 from .test_transaction import in_transaction, insert_row, inserted
-from .test_transaction import ExpectedException, ListPopAll
+from .test_transaction import ExpectedException, patch_exec
 from .test_transaction import create_test_table  # noqa  # autouse fixture
 
 pytestmark = pytest.mark.asyncio
 
 
 @pytest.fixture
-async def commands(aconn, monkeypatch):
-    """The queue of commands issued internally by the test connection."""
-    _orig_exec_command = aconn._exec_command
-    L = ListPopAll()
-
-    async def _exec_command(command):
-        if isinstance(command, bytes):
-            command = command.decode(aconn.client_encoding)
-
-        L.insert(0, command)
-        await _orig_exec_command(command)
-
-    monkeypatch.setattr(aconn, "_exec_command", _exec_command)
-    yield L
+def commands(aconn, monkeypatch):
+    """The list of commands issued internally by the test connection."""
+    yield patch_exec(aconn, monkeypatch)
 
 
 async def test_basic(aconn):