]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: sync the pipeline on commit()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 9 May 2022 00:57:05 +0000 (02:57 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 9 May 2022 01:10:54 +0000 (03:10 +0200)
This maintains the expectation that, after commit(), the operation is
really persisted. As per postgres pipeline docs:

    The client must not assume that work is committed when it sends a
    COMMIT — only when the corresponding result is received to confirm the
    commit is complete.

With this change we are effectively receiving the result of the commit,
eventually throwing an exception if it happened.

Close #296

docs/advanced/pipeline.rst
psycopg/psycopg/connection.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index e9fee5057b484f58fa966b418dea082a3936b7a9..54bc7f0df3e840cccdb9220b47af45cb14aaad4e 100644 (file)
@@ -237,7 +237,7 @@ Flushing query results to the client can happen either when a synchronization
 point is established by Psycopg:
 
 - using the `Pipeline.sync()` method;
-- on `Connection.rollback()`;
+- on `Connection.commit()` or `~Connection.rollback()`;
 - at the end of a `!Pipeline` block;
 - using a fetch method such as `Cursor.fetchone()` (which only flushes the
   query but doesn't issue a Sync and doesn't reset a pipeline state error).
index 225f0c56ee15023e788bd30152b013a36c8850b4..4440c2285c1391366c05ebdcf07def18e4830b79 100644 (file)
@@ -519,6 +519,9 @@ class BaseConnection(Generic[Row]):
 
         yield from self._exec_command(b"COMMIT")
 
+        if self._pipeline:
+            yield from self._pipeline._sync_gen()
+
     def _rollback_gen(self) -> PQGen[None]:
         """Generator implementing `Connection.rollback()`."""
         if self._num_transactions:
index e40fac3c574f036f6bc493e42c9ce80ea2562e41..ba3d3d6ef202b8b269dbc9d837401462dc9c213f 100644 (file)
@@ -217,6 +217,41 @@ def test_sync_syncs_errors(conn):
             p.sync()
 
 
+def test_errors_raised_on_commit(conn):
+    with conn.pipeline():
+        conn.execute("select 1 from nosuchtable")
+        with pytest.raises(e.UndefinedTable):
+            conn.commit()
+        conn.rollback()
+        cur1 = conn.execute("select 1")
+    cur2 = conn.execute("select 2")
+
+    assert cur1.fetchone() == (1,)
+    assert cur2.fetchone() == (2,)
+
+
+def test_error_on_commit(conn):
+    conn.execute(
+        """
+        drop table if exists selfref;
+        create table selfref (
+            x serial primary key,
+            y int references selfref (x) deferrable initially deferred)
+        """
+    )
+    conn.commit()
+
+    with conn.pipeline():
+        conn.execute("insert into selfref (y) values (-1)")
+        with pytest.raises(e.ForeignKeyViolation):
+            conn.commit()
+        cur1 = conn.execute("select 1")
+    cur2 = conn.execute("select 2")
+
+    assert cur1.fetchone() == (1,)
+    assert cur2.fetchone() == (2,)
+
+
 def test_fetch_no_result(conn):
     with conn.pipeline():
         cur = conn.cursor()
index 3bbb985791eab21503e888ffa5afdde3e301c1c2..6f774a5bdbf82e4c768dcac8d83a3a9bfe81883e 100644 (file)
@@ -220,6 +220,41 @@ async def test_sync_syncs_errors(aconn):
             await p.sync()
 
 
+async def test_errors_raised_on_commit(aconn):
+    async with aconn.pipeline():
+        await aconn.execute("select 1 from nosuchtable")
+        with pytest.raises(e.UndefinedTable):
+            await aconn.commit()
+        await aconn.rollback()
+        cur1 = await aconn.execute("select 1")
+    cur2 = await aconn.execute("select 2")
+
+    assert await cur1.fetchone() == (1,)
+    assert await cur2.fetchone() == (2,)
+
+
+async def test_error_on_commit(aconn):
+    await aconn.execute(
+        """
+        drop table if exists selfref;
+        create table selfref (
+            x serial primary key,
+            y int references selfref (x) deferrable initially deferred)
+        """
+    )
+    await aconn.commit()
+
+    async with aconn.pipeline():
+        await aconn.execute("insert into selfref (y) values (-1)")
+        with pytest.raises(e.ForeignKeyViolation):
+            await aconn.commit()
+        cur1 = await aconn.execute("select 1")
+    cur2 = await aconn.execute("select 2")
+
+    assert (await cur1.fetchone()) == (1,)
+    assert (await cur2.fetchone()) == (2,)
+
+
 async def test_fetch_no_result(aconn):
     async with aconn.pipeline():
         cur = aconn.cursor()