]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: lock the connection during a 'notifies()' call 760/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 1 Apr 2024 23:18:50 +0000 (23:18 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 13 Jun 2024 21:02:40 +0000 (23:02 +0200)
With the previous implementation, it was possible to sneak an execute()
while the generator is consumed. This gives the false impression that
it's possible to use the connection while listening (see #756), which is
false for reason better explored in #757.

Therefore, lock the connection while listening to notifications. If
someone wants to mix commands with listening on the same connection,
they should do it collaboratively with an adequately short notifies()
timeout.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_notify.py
tests/test_notify_async.py

index 8051a1b5cffa9875725679090afd20711f0bce30..9558e29c610794785a906cf2ed8147340fcd50c6 100644 (file)
@@ -326,7 +326,7 @@ class Connection(BaseConnection[Row]):
             notifications arrives in the same packet.
         """
         # Allow interrupting the wait with a signal by reducing a long timeout
-        # into shorter interval.
+        # into shorter intervals.
         if timeout is not None:
             deadline = monotonic() + timeout
             interval = min(timeout, _WAIT_INTERVAL)
@@ -336,34 +336,33 @@ class Connection(BaseConnection[Row]):
 
         nreceived = 0
 
-        while True:
-            # Collect notifications. Also get the connection encoding if any
-            # notification is received to makes sure that they are consistent.
-            try:
-                with self.lock:
+        with self.lock:
+            enc = self.pgconn._encoding
+            while True:
+                try:
                     ns = self.wait(notifies(self.pgconn), interval=interval)
-                    if ns:
-                        enc = self.pgconn._encoding
-            except e._NO_TRACEBACK as ex:
-                raise ex.with_traceback(None)
+                except e._NO_TRACEBACK as ex:
+                    raise ex.with_traceback(None)
 
-            # Emit the notifications received.
-            for pgn in ns:
-                n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
-                yield n
-                nreceived += 1
-
-            # Stop if we have received enough notifications.
-            if stop_after is not None and nreceived >= stop_after:
-                break
+                # Emit the notifications received.
+                for pgn in ns:
+                    n = Notify(
+                        pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
+                    )
+                    yield n
+                    nreceived += 1
 
-            # Check the deadline after the loop to ensure that timeout=0
-            # polls at least once.
-            if deadline:
-                interval = min(_WAIT_INTERVAL, deadline - monotonic())
-                if interval < 0.0:
+                # Stop if we have received enough notifications.
+                if stop_after is not None and nreceived >= stop_after:
                     break
 
+                # Check the deadline after the loop to ensure that timeout=0
+                # polls at least once.
+                if deadline:
+                    interval = min(_WAIT_INTERVAL, deadline - monotonic())
+                    if interval < 0.0:
+                        break
+
     @contextmanager
     def pipeline(self) -> Iterator[Pipeline]:
         """Context manager to switch the connection into pipeline mode."""
index 8252594b5ac391acc39ea7fad8aaa81afe0abf98..f82704e559ea92d89a59b25b501ef6f7c552e45f 100644 (file)
@@ -346,7 +346,7 @@ class AsyncConnection(BaseConnection[Row]):
             notifications arrives in the same packet.
         """
         # Allow interrupting the wait with a signal by reducing a long timeout
-        # into shorter interval.
+        # into shorter intervals.
         if timeout is not None:
             deadline = monotonic() + timeout
             interval = min(timeout, _WAIT_INTERVAL)
@@ -356,34 +356,33 @@ class AsyncConnection(BaseConnection[Row]):
 
         nreceived = 0
 
-        while True:
-            # Collect notifications. Also get the connection encoding if any
-            # notification is received to makes sure that they are consistent.
-            try:
-                async with self.lock:
+        async with self.lock:
+            enc = self.pgconn._encoding
+            while True:
+                try:
                     ns = await self.wait(notifies(self.pgconn), interval=interval)
-                    if ns:
-                        enc = self.pgconn._encoding
-            except e._NO_TRACEBACK as ex:
-                raise ex.with_traceback(None)
+                except e._NO_TRACEBACK as ex:
+                    raise ex.with_traceback(None)
 
-            # Emit the notifications received.
-            for pgn in ns:
-                n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
-                yield n
-                nreceived += 1
-
-            # Stop if we have received enough notifications.
-            if stop_after is not None and nreceived >= stop_after:
-                break
+                # Emit the notifications received.
+                for pgn in ns:
+                    n = Notify(
+                        pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
+                    )
+                    yield n
+                    nreceived += 1
 
-            # Check the deadline after the loop to ensure that timeout=0
-            # polls at least once.
-            if deadline:
-                interval = min(_WAIT_INTERVAL, deadline - monotonic())
-                if interval < 0.0:
+                # Stop if we have received enough notifications.
+                if stop_after is not None and nreceived >= stop_after:
                     break
 
+                # Check the deadline after the loop to ensure that timeout=0
+                # polls at least once.
+                if deadline:
+                    interval = min(_WAIT_INTERVAL, deadline - monotonic())
+                    if interval < 0.0:
+                        break
+
     @asynccontextmanager
     async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
         """Context manager to switch the connection into pipeline mode."""
index c67b331157d21868c5e34979764d56f6912ea8f0..e8bf6c91c611b85032b0f40efa8b9cae8e37a813 100644 (file)
@@ -193,3 +193,25 @@ def test_stop_after_batch(conn_cls, conn, dsn):
         assert ns[1].payload == "2"
     finally:
         gather(worker)
+
+
+@pytest.mark.slow
+def test_notifies_blocking(conn):
+
+    def listener():
+        for _ in conn.notifies(timeout=1):
+            pass
+
+    worker = spawn(listener)
+    try:
+        # Make sure the listener is listening
+        if not conn.lock.locked():
+            sleep(0.01)
+
+        t0 = time()
+        conn.execute("select 1")
+        dt = time() - t0
+    finally:
+        gather(worker)
+
+    assert dt > 0.5
index f4f0901d668705640743bff4a3b90c9ddded71fc..97e1628462247615e22ba9488dd47aac2aec6bee 100644 (file)
@@ -190,3 +190,24 @@ async def test_stop_after_batch(aconn_cls, aconn, dsn):
         assert ns[1].payload == "2"
     finally:
         await gather(worker)
+
+
+@pytest.mark.slow
+async def test_notifies_blocking(aconn):
+    async def listener():
+        async for _ in aconn.notifies(timeout=1):
+            pass
+
+    worker = spawn(listener)
+    try:
+        # Make sure the listener is listening
+        if not aconn.lock.locked():
+            await asleep(0.01)
+
+        t0 = time()
+        await aconn.execute("select 1")
+        dt = time() - t0
+    finally:
+        await gather(worker)
+
+    assert dt > 0.5