]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: always gather the notifications received 975/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 22 Dec 2024 19:38:58 +0000 (20:38 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Dec 2024 16:37:19 +0000 (17:37 +0100)
Starting to register them after the first call to notifies() is somewhat
weird. We also risk to lose notifications in a case such as:

    conn.execute("listen foo")
    conn.execute("listen bar")
    for n in conn.notifies():
        ...

docs/news.rst
psycopg/psycopg/_connection_base.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_notify.py
tests/test_notify_async.py

index e284ddedd8fdb422086e359620a9dafc44390106..d222448c5969dbe9df580373063f9ce2683a5197 100644 (file)
@@ -19,8 +19,8 @@ Python 3.3.0 (unreleased)
 Psycopg 3.2.4 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- Don't lose notifies received between two `~Connection.notifies()` calls
-  (:ticket:`#962`).
+- Don't lose notifies received whilst the `~Connection.notifies()` iterator
+  is not running (:ticket:`#962`).
 - Make sure that the notifies callback is called during the use of the
   `~Connection.notifies()` generator (:ticket:`#972`).
 
index a2b7ac44e990ee471041379dced92779e4884098..ab34a7cc0b8934c67aa14ebd88d3b9a7feae551e 100644 (file)
@@ -117,12 +117,11 @@ class BaseConnection(Generic[Row]):
         pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
 
         # Gather notifies when the notifies() generator is not running.
-        # This handler is registered after notifies() is used te first time.
-        # backlog = None means that the handler hasn't been registered.
-        self._notifies_backlog: Deque[Notify] | None = None
+        self._notifies_backlog = Deque[Notify]()
         self._notifies_backlog_handler = partial(
             BaseConnection._add_notify_to_backlog, wself
         )
+        self.add_notify_handler(self._notifies_backlog_handler)
 
         # Attribute is only set if the connection is from a pool so we can tell
         # apart a connection in the pool too (when _pool = None)
index 598dac7c87e621836d4502ba7187d6dece8eca7d..4a590805fb13db007a4a3e57a48c3a703cff2772 100644 (file)
@@ -23,7 +23,7 @@ from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import Deque, Self
+from ._compat import Self
 from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts, timeout_from_conninfo
 from ._pipeline import Pipeline
@@ -339,11 +339,9 @@ class Connection(BaseConnection[Row]):
         with self.lock:
             enc = self.pgconn._encoding
 
-            # If the backlog is set to not-None, then the handler is also set.
             # Remove the handler for the duration of this critical section to
             # avoid reporting notifies twice.
-            if self._notifies_backlog is not None:
-                self.remove_notify_handler(self._notifies_backlog_handler)
+            self.remove_notify_handler(self._notifies_backlog_handler)
 
             try:
                 while True:
@@ -378,10 +376,6 @@ class Connection(BaseConnection[Row]):
                         if interval < 0.0:
                             break
             finally:
-                # Install, or re-install, the backlog notify handler
-                # to catch notifications received while the generator was off.
-                if self._notifies_backlog is None:
-                    self._notifies_backlog = Deque()
                 self.add_notify_handler(self._notifies_backlog_handler)
 
     @contextmanager
index 577199dd00ceb0c43e191d9c66b0ff39366a1284..00700a8e12921553541067d2dd4eb10692a0b5ba 100644 (file)
@@ -20,7 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import Deque, Self
+from ._compat import Self
 from .conninfo import make_conninfo, conninfo_to_dict
 from .conninfo import conninfo_attempts_async, timeout_from_conninfo
 from ._pipeline import AsyncPipeline
@@ -359,11 +359,9 @@ class AsyncConnection(BaseConnection[Row]):
         async with self.lock:
             enc = self.pgconn._encoding
 
-            # If the backlog is set to not-None, then the handler is also set.
             # Remove the handler for the duration of this critical section to
             # avoid reporting notifies twice.
-            if self._notifies_backlog is not None:
-                self.remove_notify_handler(self._notifies_backlog_handler)
+            self.remove_notify_handler(self._notifies_backlog_handler)
 
             try:
                 while True:
@@ -401,11 +399,6 @@ class AsyncConnection(BaseConnection[Row]):
                         if interval < 0.0:
                             break
             finally:
-                # Install, or re-install, the backlog notify handler
-                # to catch notifications received while the generator was off.
-                if self._notifies_backlog is None:
-                    self._notifies_backlog = Deque()
-
                 self.add_notify_handler(self._notifies_backlog_handler)
 
     @asynccontextmanager
index 3871722ddaf0bfb4a3831b7d26e5a0ba4e46574d..8f741619dc5122043b9b939424d362d25785138b 100644 (file)
@@ -255,6 +255,23 @@ def test_generator_and_handler(conn, conn_cls, dsn):
     assert n2
 
 
+@pytest.mark.parametrize("query_between", [True, False])
+def test_first_notify_not_lost(conn, conn_cls, dsn, query_between):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    with conn_cls.connect(dsn, autocommit=True) as conn2:
+        conn2.execute("notify foo, 'hi'")
+
+    if query_between:
+        conn.execute("select 1")
+
+    n = None
+    for n in conn.notifies(timeout=1, stop_after=1):
+        pass
+    assert n
+
+
 @pytest.mark.slow
 @pytest.mark.timing
 @pytest.mark.parametrize("sleep_on", ["server", "client"])
index aebc333b02caa799a36c45ed1afa8a6fb2089c9b..8c87f3349635b9d02283f032ac4b9eee23162abc 100644 (file)
@@ -252,6 +252,23 @@ async def test_generator_and_handler(aconn, aconn_cls, dsn):
     assert n2
 
 
+@pytest.mark.parametrize("query_between", [True, False])
+async def test_first_notify_not_lost(aconn, aconn_cls, dsn, query_between):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async with await aconn_cls.connect(dsn, autocommit=True) as conn2:
+        await conn2.execute("notify foo, 'hi'")
+
+    if query_between:
+        await aconn.execute("select 1")
+
+    n = None
+    async for n in aconn.notifies(timeout=1, stop_after=1):
+        pass
+    assert n
+
+
 @pytest.mark.slow
 @pytest.mark.timing
 @pytest.mark.parametrize("sleep_on", ["server", "client"])