From: Daniele Varrazzo Date: Mon, 4 Apr 2022 01:41:49 +0000 (+0200) Subject: fix: sync pipeline state on rollback X-Git-Tag: 3.1~121 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F267%2Fhead;p=thirdparty%2Fpsycopg.git fix: sync pipeline state on rollback --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 65bc2f658..fc92c0512 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -531,6 +531,11 @@ class BaseConnection(Generic[Row]): raise e.ProgrammingError( "rollback() cannot be used during a two-phase transaction" ) + + # Get out of a "pipeline aborted" state + if self._pipeline and self.pgconn.pipeline_status == pq.PipelineStatus.ABORTED: + yield from self._pipeline._sync_gen() + if self.pgconn.transaction_status == TransactionStatus.IDLE: return @@ -539,6 +544,9 @@ class BaseConnection(Generic[Row]): for cmd in self._prepared.get_maintenance_commands(): yield from self._exec_command(cmd) + if self._pipeline: + yield from self._pipeline._sync_gen() + def xid(self, format_id: int, gtrid: str, bqual: str) -> Xid: """ Returns a `Xid` to pass to the `!tpc_*()` methods of this connection. diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index b8e36867e..35ef973d8 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -56,6 +56,7 @@ class BaseTransaction(Generic[ConnectionType]): force_rollback: bool = False, ): self._conn = connection + self.pgconn = self._conn.pgconn self._savepoint_name = savepoint_name or "" self.force_rollback = force_rollback self._entered = self._exited = False @@ -73,7 +74,7 @@ class BaseTransaction(Generic[ConnectionType]): def __repr__(self) -> str: cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}" - info = pq.misc.connection_summary(self._conn.pgconn) + info = pq.misc.connection_summary(self.pgconn) if not self._entered: status = "inactive" elif not self._exited: @@ -136,6 +137,13 @@ class BaseTransaction(Generic[ConnectionType]): if ex: raise ex + # Get out of a "pipeline aborted" state + if ( + self._conn._pipeline + and self.pgconn.pipeline_status == pq.PipelineStatus.ABORTED + ): + yield from self._conn._pipeline._sync_gen() + for command in self._get_rollback_commands(): yield from self._conn._exec_command(command) @@ -196,7 +204,7 @@ class BaseTransaction(Generic[ConnectionType]): Also set the internal state of the object and verify consistency. """ self._outer_transaction = ( - self._conn.pgconn.transaction_status == TransactionStatus.IDLE + self.pgconn.transaction_status == TransactionStatus.IDLE ) if self._outer_transaction: # outer transaction: if no name it's only a begin, else @@ -248,7 +256,7 @@ class Transaction(BaseTransaction["Connection[Any]"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> bool: - if self._conn.pgconn.status == ConnStatus.OK: + if self.pgconn.status == ConnStatus.OK: with self._conn.lock: return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) else: @@ -277,7 +285,7 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> bool: - if self._conn.pgconn.status == ConnStatus.OK: + if self.pgconn.status == ConnStatus.OK: async with self._conn.lock: return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb)) else: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index aba1f9b04..d5ef4185e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -360,6 +360,26 @@ def test_outer_transaction_error(conn): conn.execute("create table voila ()") +def test_rollback_explicit(conn): + conn.autocommit = True + with conn.pipeline(): + with pytest.raises(e.DivisionByZero): + cur = conn.execute("select 1 / %s", [0]) + cur.fetchone() + conn.rollback() + conn.execute("select 1") + + +def test_rollback_transaction(conn): + conn.autocommit = True + with pytest.raises(e.DivisionByZero): + with conn.pipeline(): + with conn.transaction(): + cur = conn.execute("select 1 / %s", [0]) + cur.fetchone() + conn.execute("select 1") + + def test_concurrency(conn): with conn.transaction(): conn.execute("drop table if exists pipeline_concurrency") diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index 269af39e6..420734824 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -364,6 +364,26 @@ async def test_outer_transaction_error(aconn): await aconn.execute("create table voila ()") +async def test_rollback_explicit(aconn): + await aconn.set_autocommit(True) + async with aconn.pipeline(): + with pytest.raises(e.DivisionByZero): + cur = await aconn.execute("select 1 / %s", [0]) + await cur.fetchone() + await aconn.rollback() + await aconn.execute("select 1") + + +async def test_rollback_transaction(aconn): + await aconn.set_autocommit(True) + with pytest.raises(e.DivisionByZero): + async with aconn.pipeline(): + async with aconn.transaction(): + cur = await aconn.execute("select 1 / %s", [0]) + await cur.fetchone() + await aconn.execute("select 1") + + async def test_concurrency(aconn): async with aconn.transaction(): await aconn.execute("drop table if exists pipeline_concurrency")