fix: don't lose notifications between notifies() calls
author     Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Sat, 21 Dec 2024 01:16:57 +0000 (02:16 +0100)
committer  Daniele Varrazzo <daniele.varrazzo@gmail.com>
           Sun, 22 Dec 2024 17:47:14 +0000 (18:47 +0100)
This allows stopping the generator periodically to run some queries (for
example to LISTEN/UNLISTEN certain channels) and starting it again
without fear of losing notifications in the window.

Close #962.
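
A usage sketch of the intended pattern (illustrative only, not part of the
patch; the connection string and channel names are made up):

    import psycopg

    with psycopg.connect("dbname=test", autocommit=True) as conn:
        conn.execute("LISTEN chan_a")

        # First run of the generator: consume notifications for up to 5 seconds.
        for n in conn.notifies(timeout=5):
            print(n.channel, n.payload)

        # With this fix, notifications arriving from here on are no longer
        # dropped: they are kept in a backlog and yielded first by the next
        # notifies() call.
        conn.execute("LISTEN chan_b")
        conn.execute("UNLISTEN chan_a")

        # Second run: backlog first, then newly received notifications.
        for n in conn.notifies(timeout=5):
            print(n.channel, n.payload)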

docs/news.rst
psycopg/psycopg/_connection_base.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_notify.py
tests/test_notify_async.py

docs/news.rst
index b81b1f70c25ce11d2bdb25d912d00a28c09afecd..3b520301696312b8914d395eab40871591fe8b20 100644 (file)
@@ -13,6 +13,8 @@ Future releases
 Psycopg 3.2.4 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Don't lose notifies received between two `~Connection.notifies()` calls
+  (:ticket:`#962`).
 - Make sure that the notifies callback is called during the use of the
   `~Connection.notifies()` generator (:ticket:`#972`).
 
psycopg/psycopg/_connection_base.py
index 5741b849f1f5ff30719b100e5590464c96c6eb0b..a2b7ac44e990ee471041379dced92779e4884098 100644 (file)
@@ -23,7 +23,7 @@ from ._tpc import Xid
 from .rows import Row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import LiteralString, Self, TypeAlias, TypeVar
+from ._compat import Deque, LiteralString, Self, TypeAlias, TypeVar
 from .pq.misc import connection_summary
 from ._pipeline import BasePipeline
 from ._preparing import PrepareManager
@@ -116,6 +116,14 @@ class BaseConnection(Generic[Row]):
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
         pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
 
+        # Gather notifies when the notifies() generator is not running.
+        # This handler is registered after notifies() is used the first time.
+        # backlog = None means that the handler hasn't been registered.
+        self._notifies_backlog: Deque[Notify] | None = None
+        self._notifies_backlog_handler = partial(
+            BaseConnection._add_notify_to_backlog, wself
+        )
+
         # Attribute is only set if the connection is from a pool so we can tell
         # apart a connection in the pool too (when _pool = None)
         self._pool: BasePool | None
@@ -377,6 +385,15 @@ class BaseConnection(Generic[Row]):
         for cb in self._notify_handlers:
             cb(n)
 
+    @staticmethod
+    def _add_notify_to_backlog(
+        wself: ReferenceType[BaseConnection[Row]], notify: Notify
+    ) -> None:
+        self = wself()
+        if not self or self._notifies_backlog is None:
+            return
+        self._notifies_backlog.append(notify)
+
     @property
     def prepare_threshold(self) -> int | None:
         """
psycopg/psycopg/connection.py
index 9558e29c610794785a906cf2ed8147340fcd50c6..598dac7c87e621836d4502ba7187d6dece8eca7d 100644 (file)
@@ -23,7 +23,7 @@ from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import Self
+from ._compat import Deque, Self
 from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts, timeout_from_conninfo
 from ._pipeline import Pipeline
@@ -338,31 +338,52 @@ class Connection(BaseConnection[Row]):
 
         with self.lock:
             enc = self.pgconn._encoding
-            while True:
-                try:
-                    ns = self.wait(notifies(self.pgconn), interval=interval)
-                except e._NO_TRACEBACK as ex:
-                    raise ex.with_traceback(None)
-
-                # Emit the notifications received.
-                for pgn in ns:
-                    n = Notify(
-                        pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
-                    )
-                    yield n
-                    nreceived += 1
-
-                # Stop if we have received enough notifications.
-                if stop_after is not None and nreceived >= stop_after:
-                    break
-
-                # Check the deadline after the loop to ensure that timeout=0
-                # polls at least once.
-                if deadline:
-                    interval = min(_WAIT_INTERVAL, deadline - monotonic())
-                    if interval < 0.0:
+
+            # If the backlog is not None, then the handler is also registered.
+            # Remove the handler for the duration of this critical section to
+            # avoid reporting notifies twice.
+            if self._notifies_backlog is not None:
+                self.remove_notify_handler(self._notifies_backlog_handler)
+
+            try:
+                while True:
+                    # if notifies were received when the generator was off,
+                    # return them in a first batch.
+                    if self._notifies_backlog:
+                        while self._notifies_backlog:
+                            yield self._notifies_backlog.popleft()
+                            nreceived += 1
+                    else:
+                        try:
+                            pgns = self.wait(notifies(self.pgconn), interval=interval)
+                        except e._NO_TRACEBACK as ex:
+                            raise ex.with_traceback(None)
+                        # Emit the notifications received.
+                        for pgn in pgns:
+                            yield Notify(
+                                pgn.relname.decode(enc),
+                                pgn.extra.decode(enc),
+                                pgn.be_pid,
+                            )
+                            nreceived += 1
+
+                    # Stop if we have received enough notifications.
+                    if stop_after is not None and nreceived >= stop_after:
                         break
 
+                    # Check the deadline after the loop to ensure that timeout=0
+                    # polls at least once.
+                    if deadline:
+                        interval = min(_WAIT_INTERVAL, deadline - monotonic())
+                        if interval < 0.0:
+                            break
+            finally:
+                # Install, or re-install, the backlog notify handler
+                # to catch notifications received while the generator was off.
+                if self._notifies_backlog is None:
+                    self._notifies_backlog = Deque()
+                self.add_notify_handler(self._notifies_backlog_handler)
+
     @contextmanager
     def pipeline(self) -> Iterator[Pipeline]:
         """Context manager to switch the connection into pipeline mode."""
psycopg/psycopg/connection_async.py
index f82704e559ea92d89a59b25b501ef6f7c552e45f..577199dd00ceb0c43e191d9c66b0ff39366a1284 100644 (file)
@@ -20,7 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import Self
+from ._compat import Deque, Self
 from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
@@ -358,31 +358,56 @@ class AsyncConnection(BaseConnection[Row]):
 
         async with self.lock:
             enc = self.pgconn._encoding
-            while True:
-                try:
-                    ns = await self.wait(notifies(self.pgconn), interval=interval)
-                except e._NO_TRACEBACK as ex:
-                    raise ex.with_traceback(None)
-
-                # Emit the notifications received.
-                for pgn in ns:
-                    n = Notify(
-                        pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
-                    )
-                    yield n
-                    nreceived += 1
-
-                # Stop if we have received enough notifications.
-                if stop_after is not None and nreceived >= stop_after:
-                    break
-
-                # Check the deadline after the loop to ensure that timeout=0
-                # polls at least once.
-                if deadline:
-                    interval = min(_WAIT_INTERVAL, deadline - monotonic())
-                    if interval < 0.0:
+
+            # If the backlog is not None, then the handler is also registered.
+            # Remove the handler for the duration of this critical section to
+            # avoid reporting notifies twice.
+            if self._notifies_backlog is not None:
+                self.remove_notify_handler(self._notifies_backlog_handler)
+
+            try:
+                while True:
+                    # if notifies were received when the generator was off,
+                    # return them in a first batch.
+                    if self._notifies_backlog:
+                        while self._notifies_backlog:
+                            yield self._notifies_backlog.popleft()
+                            nreceived += 1
+                    else:
+                        try:
+                            pgns = await self.wait(
+                                notifies(self.pgconn), interval=interval
+                            )
+                        except e._NO_TRACEBACK as ex:
+                            raise ex.with_traceback(None)
+
+                        # Emit the notifications received.
+                        for pgn in pgns:
+                            yield Notify(
+                                pgn.relname.decode(enc),
+                                pgn.extra.decode(enc),
+                                pgn.be_pid,
+                            )
+                            nreceived += 1
+
+                    # Stop if we have received enough notifications.
+                    if stop_after is not None and nreceived >= stop_after:
                         break
 
+                    # Check the deadline after the loop to ensure that timeout=0
+                    # polls at least once.
+                    if deadline:
+                        interval = min(_WAIT_INTERVAL, deadline - monotonic())
+                        if interval < 0.0:
+                            break
+            finally:
+                # Install, or re-install, the backlog notify handler
+                # to catch notifications received while the generator was off.
+                if self._notifies_backlog is None:
+                    self._notifies_backlog = Deque()
+
+                self.add_notify_handler(self._notifies_backlog_handler)
+
     @asynccontextmanager
     async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
         """Context manager to switch the connection into pipeline mode."""
tests/test_notify.py
index 157d64b0dbc5dbf61079f3a8680d18fb0e05f251..3871722ddaf0bfb4a3831b7d26e5a0ba4e46574d 100644 (file)
@@ -8,7 +8,7 @@ from time import time
 import pytest
 from psycopg import Notify
 
-from .acompat import sleep, gather, spawn
+from .acompat import Event, sleep, gather, spawn
 
 pytestmark = pytest.mark.crdb_skip("notify")
 
@@ -253,3 +253,45 @@ def test_generator_and_handler(conn, conn_cls, dsn):
 
     assert n1
     assert n2
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize("sleep_on", ["server", "client"])
+def test_notify_query_notify(conn_cls, dsn, sleep_on):
+    e = Event()
+    by_gen: list[int] = []
+    by_cb: list[int] = []
+    workers = []
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as conn:
+            sleep(0.1)
+            for i in range(3):
+                conn.execute("select pg_notify('counter', %s)", (str(i),))
+                sleep(0.2)
+
+    def listener():
+        with conn_cls.connect(dsn, autocommit=True) as conn:
+            conn.add_notify_handler(lambda n: by_cb.append(int(n.payload)))
+
+            conn.execute("listen counter")
+            e.set()
+            for n in conn.notifies(timeout=0.2):
+                by_gen.append(int(n.payload))
+
+            if sleep_on == "server":
+                conn.execute("select pg_sleep(0.2)")
+            else:
+                assert sleep_on == "client"
+                sleep(0.2)
+
+            for n in conn.notifies(timeout=0.2):
+                by_gen.append(int(n.payload))
+
+    workers.append(spawn(listener))
+    e.wait()
+    workers.append(spawn(notifier))
+    gather(*workers)
+
+    assert list(range(3)) == by_cb == by_gen, f"by_gen={by_gen!r}, by_cb={by_cb!r}"
tests/test_notify_async.py
index 68ffd9463d73a98fc059898faf2050d6c7a57be1..aebc333b02caa799a36c45ed1afa8a6fb2089c9b 100644 (file)
@@ -5,7 +5,7 @@ from time import time
 import pytest
 from psycopg import Notify
 
-from .acompat import alist, asleep, gather, spawn
+from .acompat import AEvent, alist, asleep, gather, spawn
 
 pytestmark = pytest.mark.crdb_skip("notify")
 
@@ -250,3 +250,45 @@ async def test_generator_and_handler(aconn, aconn_cls, dsn):
 
     assert n1
     assert n2
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize("sleep_on", ["server", "client"])
+async def test_notify_query_notify(aconn_cls, dsn, sleep_on):
+    e = AEvent()
+    by_gen: list[int] = []
+    by_cb: list[int] = []
+    workers = []
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as aconn:
+            await asleep(0.1)
+            for i in range(3):
+                await aconn.execute("select pg_notify('counter', %s)", (str(i),))
+                await asleep(0.2)
+
+    async def listener():
+        async with await aconn_cls.connect(dsn, autocommit=True) as aconn:
+            aconn.add_notify_handler(lambda n: by_cb.append(int(n.payload)))
+
+            await aconn.execute("listen counter")
+            e.set()
+            async for n in aconn.notifies(timeout=0.2):
+                by_gen.append(int(n.payload))
+
+            if sleep_on == "server":
+                await aconn.execute("select pg_sleep(0.2)")
+            else:
+                assert sleep_on == "client"
+                await asleep(0.2)
+
+            async for n in aconn.notifies(timeout=0.2):
+                by_gen.append(int(n.payload))
+
+    workers.append(spawn(listener))
+    await e.wait()
+    workers.append(spawn(notifier))
+    await gather(*workers)
+
+    assert list(range(3)) == by_cb == by_gen, f"{by_gen=}, {by_cb=}"