]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use a weak reference to avoid loops between pool and maintenance tasks
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 21 Feb 2021 01:17:41 +0000 (02:17 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool.py

index 4f2d0c0dfef674816fe97f98a05449dca1810beb..f28c902fd6ff864c8e9dd3a81a22b7c1ed5ca50b 100644 (file)
@@ -11,6 +11,7 @@ import threading
 from abc import ABC, abstractmethod
 from queue import Queue, Empty
 from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Tuple
+from weakref import ref
 from contextlib import contextmanager
 from collections import deque
 
@@ -193,7 +194,7 @@ class ConnectionPool:
         logger.debug("returning connection to %r", self.name)
 
         # If the pool is closed just close the connection instead of returning
-        # it to the poo. For extra refcare remove the pool reference from it.
+        # it to the pool. For extra refcare remove the pool reference from it.
         if self._closed:
             conn._pool = None
             conn.close()
@@ -344,11 +345,8 @@ class ConnectionPool:
                     "task run %s failed: %s: %s", task, e.__class__.__name__, e
                 )
 
-            # delete reference loops which may keep the pool alive
-            del task.pool
             if isinstance(task, StopWorker):
                 return
-            del task
 
     def _connect(self) -> Connection:
         """Return a new connection configured for the pool."""
@@ -395,36 +393,41 @@ class MaintenanceTask(ABC):
     """A task run asynchronously to maintain the pool state."""
 
     def __init__(self, pool: ConnectionPool):
-        self.pool = pool
+        self.pool = ref(pool)
         logger.debug("task created: %s", self)
 
     def __repr__(self) -> str:
-        return (
-            f"<{self.__class__.__name__} {self.pool.name!r} at 0x{id(self):x}>"
-        )
+        pool = self.pool()
+        name = repr(pool.name) if pool else "<pool is gone>"
+        return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
 
     def run(self) -> None:
+        pool = self.pool()
+        if not pool:
+            # Pool has been deleted. Quietly discard operation.
+            return
+
         logger.debug("task running: %s", self)
-        self._run()
+        self._run(pool)
 
     @abstractmethod
-    def _run(self) -> None:
+    def _run(self, pool: ConnectionPool) -> None:
         ...
 
 
 class StopWorker(MaintenanceTask):
     """Signal the maintenance thread to terminate."""
 
-    def _run(self) -> None:
+    def _run(self, pool: ConnectionPool) -> None:
         pass
 
 
 class AddConnection(MaintenanceTask):
     """Add a new connection into to the pool."""
 
-    def _run(self) -> None:
-        conn = self.pool._connect()
-        self.pool._add_to_pool(conn)
+    def _run(self, pool: ConnectionPool) -> None:
+        conn = pool._connect()
+        pool._add_to_pool(conn)
 
 
 class AddInitialConnection(AddConnection):
@@ -437,9 +440,9 @@ class AddInitialConnection(AddConnection):
         super().__init__(pool)
         self.event = event
 
-    def _run(self) -> None:
-        super()._run()
-        if len(self.pool._pool) >= self.pool._nconns:
+    def _run(self, pool: ConnectionPool) -> None:
+        super()._run(pool)
+        if len(pool._pool) >= pool._nconns:
             self.event.set()
 
 
@@ -450,5 +453,5 @@ class ReturnConnection(MaintenanceTask):
         super().__init__(pool)
         self.conn = conn
 
-    def _run(self) -> None:
-        self.pool._add_to_pool(self.conn)
+    def _run(self, pool: ConnectionPool) -> None:
+        pool._add_to_pool(self.conn)