From: Daniele Varrazzo Date: Mon, 9 May 2022 00:57:05 +0000 (+0200) Subject: fix: sync the pipeline on commit() X-Git-Tag: 3.1~113^2~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=666d2f8d5bc5d42f83ae807cacd0981635556f34;p=thirdparty%2Fpsycopg.git fix: sync the pipeline on commit() This maintains the expectation that, after commit(), the operation is really persisted. As per postgres pipeline docs: The client must not assume that work is committed when it sends a COMMIT — only when the corresponding result is received to confirm the commit is complete. With this change we are effectively receiving the result of the commit, eventually throwing an exception if it happened. Close #296 --- diff --git a/docs/advanced/pipeline.rst b/docs/advanced/pipeline.rst index e9fee5057..54bc7f0df 100644 --- a/docs/advanced/pipeline.rst +++ b/docs/advanced/pipeline.rst @@ -237,7 +237,7 @@ Flushing query results to the client can happen either when a synchronization point is established by Psycopg: - using the `Pipeline.sync()` method; -- on `Connection.rollback()`; +- on `Connection.commit()` or `~Connection.rollback()`; - at the end of a `!Pipeline` block; - using a fetch method such as `Cursor.fetchone()` (which only flushes the query but doesn't issue a Sync and doesn't reset a pipeline state error). diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 225f0c56e..4440c2285 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -519,6 +519,9 @@ class BaseConnection(Generic[Row]): yield from self._exec_command(b"COMMIT") + if self._pipeline: + yield from self._pipeline._sync_gen() + def _rollback_gen(self) -> PQGen[None]: """Generator implementing `Connection.rollback()`.""" if self._num_transactions: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e40fac3c5..ba3d3d6ef 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -217,6 +217,41 @@ def test_sync_syncs_errors(conn): p.sync() +def test_errors_raised_on_commit(conn): + with conn.pipeline(): + conn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + conn.commit() + conn.rollback() + cur1 = conn.execute("select 1") + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + +def test_error_on_commit(conn): + conn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + conn.commit() + + with conn.pipeline(): + conn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + conn.commit() + cur1 = conn.execute("select 1") + cur2 = conn.execute("select 2") + + assert cur1.fetchone() == (1,) + assert cur2.fetchone() == (2,) + + def test_fetch_no_result(conn): with conn.pipeline(): cur = conn.cursor() diff --git a/tests/test_pipeline_async.py b/tests/test_pipeline_async.py index 3bbb98579..6f774a5bd 100644 --- a/tests/test_pipeline_async.py +++ b/tests/test_pipeline_async.py @@ -220,6 +220,41 @@ async def test_sync_syncs_errors(aconn): await p.sync() +async def test_errors_raised_on_commit(aconn): + async with aconn.pipeline(): + await aconn.execute("select 1 from nosuchtable") + with pytest.raises(e.UndefinedTable): + await aconn.commit() + await aconn.rollback() + cur1 = await aconn.execute("select 1") + cur2 = await aconn.execute("select 2") + + assert await cur1.fetchone() == (1,) + assert await cur2.fetchone() == (2,) + + +async def test_error_on_commit(aconn): + await aconn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) + await aconn.commit() + + async with aconn.pipeline(): + await aconn.execute("insert into selfref (y) values (-1)") + with pytest.raises(e.ForeignKeyViolation): + await aconn.commit() + cur1 = await aconn.execute("select 1") + cur2 = await aconn.execute("select 2") + + assert (await cur1.fetchone()) == (1,) + assert (await cur2.fetchone()) == (2,) + + async def test_fetch_no_result(aconn): async with aconn.pipeline(): cur = aconn.cursor()