]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: don't keep the notifiers backlog handler in the connection state 992/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 10 Jan 2025 14:17:01 +0000 (15:17 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 10 Jan 2025 20:36:47 +0000 (21:36 +0100)
Just keep the queue in the state and special-case its handling in the
`_notify_handler` connection method instead of registering a standard handler.
Set the queue to None to signify that we are in the `notifies()` generator.

This way we don't need the awkward weak-self + class method to avoid a
reference loop and to dereference the connection weak reference another
time, as we just did in `_notify_handler()`. Setting the queue to None
also feels cleaner than adding/removing the handler.

Relates to #975.

psycopg/psycopg/_connection_base.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py

index a260fe177ec1bcc29815a6361c9231fe2a32d0cf..a5222f6318837cb5768a81d917660d03f3604ff3 100644 (file)
@@ -113,17 +113,14 @@ class BaseConnection(Generic[Row]):
         self._prepared: PrepareManager = PrepareManager()
         self._tpc: tuple[Xid, bool] | None = None  # xid, prepared
 
+        # Gather notifies when the notifies() generator is not running.
+        # It will be set to None during `notifies()` generator run.
+        self._notifies_backlog: deque[Notify] | None = deque()
+
         wself = ref(self)
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
         pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
 
-        # Gather notifies when the notifies() generator is not running.
-        self._notifies_backlog = deque[Notify]()
-        self._notifies_backlog_handler = partial(
-            BaseConnection._add_notify_to_backlog, wself
-        )
-        self.add_notify_handler(self._notifies_backlog_handler)
-
         # Attribute is only set if the connection is from a pool so we can tell
         # apart a connection in the pool too (when _pool = None)
         self._pool: BasePool | None
@@ -376,24 +373,19 @@ class BaseConnection(Generic[Row]):
     def _notify_handler(
         wself: ReferenceType[BaseConnection[Row]], pgn: pq.PGnotify
     ) -> None:
-        self = wself()
-        if not (self and self._notify_handlers):
+        if not (self := wself()):
             return
 
         enc = self.pgconn._encoding
         n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
+
+        # `_notifies_backlog` is None if the `notifies()` generator is running
+        if (d := self._notifies_backlog) is not None:
+            d.append(n)
+
         for cb in self._notify_handlers:
             cb(n)
 
-    @staticmethod
-    def _add_notify_to_backlog(
-        wself: ReferenceType[BaseConnection[Row]], notify: Notify
-    ) -> None:
-        self = wself()
-        if not self or self._notifies_backlog is None:
-            return
-        self._notifies_backlog.append(notify)
-
     @property
     def prepare_threshold(self) -> int | None:
         """
index 2f0c03e5adef838a85c4826db6552271e5efa6f6..a0fd1cdee9ee9a825e7c49f7ab4adaeeb004fb0e 100644 (file)
@@ -340,17 +340,17 @@ class Connection(BaseConnection[Row]):
         with self.lock:
             enc = self.pgconn._encoding
 
-            # Remove the handler for the duration of this critical section to
-            # avoid reporting notifies twice.
-            self.remove_notify_handler(self._notifies_backlog_handler)
+            # Remove the backlog deque for the duration of this critical
+            # section to avoid reporting notifies twice.
+            self._notifies_backlog, d = (None, self._notifies_backlog)
 
             try:
                 while True:
                     # if notifies were received when the generator was off,
                     # return them in a first batch.
-                    if self._notifies_backlog:
-                        while self._notifies_backlog:
-                            yield self._notifies_backlog.popleft()
+                    if d:
+                        while d:
+                            yield d.popleft()
                             nreceived += 1
                     else:
                         try:
@@ -377,7 +377,7 @@ class Connection(BaseConnection[Row]):
                         if interval < 0.0:
                             break
             finally:
-                self.add_notify_handler(self._notifies_backlog_handler)
+                self._notifies_backlog = d
 
     @contextmanager
     def pipeline(self) -> Iterator[Pipeline]:
index 20397607859f8e38d47842610e48041e704912e8..840839508b14232b27ac369265e0a84ae305e084 100644 (file)
@@ -359,17 +359,17 @@ class AsyncConnection(BaseConnection[Row]):
         async with self.lock:
             enc = self.pgconn._encoding
 
-            # Remove the handler for the duration of this critical section to
-            # avoid reporting notifies twice.
-            self.remove_notify_handler(self._notifies_backlog_handler)
+            # Remove the backlog deque for the duration of this critical
+            # section to avoid reporting notifies twice.
+            self._notifies_backlog, d = None, self._notifies_backlog
 
             try:
                 while True:
                     # if notifies were received when the generator was off,
                     # return them in a first batch.
-                    if self._notifies_backlog:
-                        while self._notifies_backlog:
-                            yield self._notifies_backlog.popleft()
+                    if d:
+                        while d:
+                            yield d.popleft()
                             nreceived += 1
                     else:
                         try:
@@ -399,7 +399,7 @@ class AsyncConnection(BaseConnection[Row]):
                         if interval < 0.0:
                             break
             finally:
-                self.add_notify_handler(self._notifies_backlog_handler)
+                self._notifies_backlog = d
 
     @asynccontextmanager
     async def pipeline(self) -> AsyncIterator[AsyncPipeline]: