Don't use threads to manage the scheduler and workers of async tasks
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Sun, 7 Mar 2021 23:52:58 +0000 (00:52 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Fri, 12 Mar 2021 04:07:25 +0000 (05:07 +0100)
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/base.py
psycopg3/psycopg3/pool/pool.py
psycopg3/psycopg3/pool/tasks.py [deleted file]
tests/pool/test_pool.py
tests/pool/test_pool_async.py

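The gist of the change: in the async pool, the scheduler and the maintenance workers become asyncio tasks consuming an asyncio.Queue on the pool's own event loop, instead of daemon threads. A minimal, self-contained sketch of that pattern follows; the names are illustrative, not the pool's actual classes.

import asyncio


class StopWorker:
    """Sentinel telling a worker to exit."""


async def worker(q: "asyncio.Queue") -> None:
    while True:
        task = await q.get()
        if isinstance(task, StopWorker):
            return
        try:
            await task()  # each queued item is an awaitable callable
        except Exception as e:
            print(f"task failed: {e!r}")  # log and keep the worker alive


async def main() -> None:
    q: "asyncio.Queue" = asyncio.Queue()
    workers = [asyncio.create_task(worker(q)) for _ in range(3)]

    async def ping() -> None:
        print("maintenance tick")

    q.put_nowait(ping)
    for _ in workers:
        q.put_nowait(StopWorker())  # one sentinel per worker
    await asyncio.gather(*workers)


asyncio.run(main())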
index 144b58e1f322dccbf7e5b5ebd7a089db2f4934b0..a8400aa1b63ae29f96d5f38e108436de9cabb34a 100644 (file)
@@ -6,18 +6,20 @@ psycopg3 synchronous connection pool
 
 import asyncio
 import logging
+from abc import ABC, abstractmethod
 from time import monotonic
 from types import TracebackType
 from typing import Any, AsyncIterator, Awaitable, Callable, Deque
-from typing import Optional, Type
+from typing import List, Optional, Type
+from weakref import ref
 from collections import deque
 
 from ..pq import TransactionStatus
 from ..connection import AsyncConnection
-from ..utils.compat import asynccontextmanager, get_running_loop
+from ..utils.compat import asynccontextmanager, create_task
 
-from . import tasks
 from .base import ConnectionAttempt, BasePool
+from .sched import AsyncScheduler
 from .errors import PoolClosed, PoolTimeout
 
 logger = logging.getLogger(__name__)
@@ -40,10 +42,64 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         # to notify that the pool is full
         self._pool_full_event: Optional[asyncio.Event] = None
 
-        self.loop = get_running_loop()
+        self._sched = AsyncScheduler()
+        self._tasks: "asyncio.Queue[MaintenanceTask]" = asyncio.Queue()
+        self._workers: "List[asyncio.Task[None]]" = []
 
         super().__init__(conninfo, **kwargs)
 
+        self._sched_runner = create_task(
+            self._sched.run(), name=f"{self.name}-scheduler"
+        )
+        for i in range(self.num_workers):
+            t = create_task(
+                self.worker(self._tasks),
+                name=f"{self.name}-worker-{i}",
+            )
+            self._workers.append(t)
+
+        # populate the pool with initial minconn connections in background
+        for i in range(self._nconns):
+            self.run_task(AddConnection(self))
+
+        # Schedule a task to shrink the pool if connections over minconn have
+        # remained unused.
+        self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
+
+    def run_task(self, task: "MaintenanceTask") -> None:
+        """Run a maintenance task in a worker thread."""
+        self._tasks.put_nowait(task)
+
+    async def schedule_task(
+        self, task: "MaintenanceTask", delay: float
+    ) -> None:
+        """Run a maintenance task in a worker thread in the future."""
+        await self._sched.enter(delay, task.tick)
+
+    @classmethod
+    async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None:
+        """Runner to execute pending maintenance task.
+
+        The function is designed to run as a separate thread.
+
+        Block on the queue *q*, run a task received. Finish running if a
+        StopWorker is received.
+        """
+        while True:
+            task = await q.get()
+
+            if isinstance(task, StopWorker):
+                logger.debug("terminating working task")
+                return
+
+            # Run the task. Make sure we don't die in the attempt.
+            try:
+                await task.run()
+            except Exception as e:
+                logger.warning(
+                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
+                )
+
     async def wait_ready(self, timeout: float = 30.0) -> None:
         """
         Wait for the pool to be full after init.
@@ -125,7 +181,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                     logger.info(
                         "growing pool %r to %s", self.name, self._nconns
                     )
-                    self.run_task(tasks.AddConnection(self))
+                    self.run_task(AddConnection(self))
 
         # If we are in the waiting queue, wait to be assigned a connection
         # (outside the critical section, so only the waiting client is locked)
@@ -168,9 +224,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             return
 
         # Use a worker to perform eventual maintenance work in a separate thread
-        self.run_task(tasks.ReturnConnection(self, conn))
+        self.run_task(ReturnConnection(self, conn))
 
-    async def close(self, timeout: float = 1.0) -> None:
+    async def close(self, timeout: float = 5.0) -> None:
         """Close the pool and make it unavailable to new clients.
 
         All the waiting and future client will fail to acquire a connection
@@ -196,11 +252,11 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         # putconn will just close the returned connection.
 
         # Stop the scheduler
-        self._sched.enter(0, None)
+        await self._sched.enter(0, None)
 
         # Stop the worker threads
-        for i in range(len(self._workers)):
-            self.run_task(tasks.StopWorker(self))
+        for w in self._workers:
+            self.run_task(StopWorker(self))
 
         # Signal to eventual clients in the queue that business is closed.
         for pos in waiting:
@@ -211,18 +267,17 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             await conn.close()
 
         # Wait for the worker threads to terminate
+        wait = asyncio.gather(self._sched_runner, *self._workers)
         if timeout > 0:
-            for t in [self._sched_runner] + self._workers:
-                if not t.is_alive():
-                    continue
-                await self.loop.run_in_executor(None, lambda: t.join(timeout))
-                if t.is_alive():
-                    logger.warning(
-                        "couldn't stop thread %s in pool %r within %s seconds",
-                        t,
-                        self.name,
-                        timeout,
-                    )
+            wait = asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
+        try:
+            await wait
+        except asyncio.TimeoutError:
+            logger.warning(
+                "couldn't stop pool %r tasks within %s seconds",
+                self.name,
+                timeout,
+            )
 
     async def __aenter__(self) -> "AsyncConnectionPool":
         return self
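The rewritten close() above bounds the shutdown wait with asyncio.wait_for over a shielded gather, so a timeout only produces a warning and does not cancel the scheduler and worker tasks. A standalone sketch of that pattern (wait_for_tasks is a made-up helper name, not part of the pool's API):

import asyncio
import logging

logger = logging.getLogger(__name__)


async def wait_for_tasks(tasks, timeout: float) -> None:
    wait = asyncio.gather(*tasks)
    if timeout > 0:
        # shield() keeps the timeout from cancelling the gathered tasks:
        # they can still finish on their own after the warning.
        wait = asyncio.wait_for(asyncio.shield(wait), timeout=timeout)
    try:
        await wait
    except asyncio.TimeoutError:
        logger.warning("tasks still running after %s seconds", timeout)


# usage (assumed): await wait_for_tasks([sched_task, *worker_tasks], 5.0)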
@@ -254,7 +309,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
             self._nconns += ngrow
 
         for i in range(ngrow):
-            self.run_task(tasks.AddConnection(self))
+            self.run_task(AddConnection(self))
 
     async def configure(self, conn: AsyncConnection) -> None:
         """Configure a connection after creation."""
@@ -309,8 +364,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
                 self.reconnect_failed()
             else:
                 attempt.update_delay(now)
-                self.schedule_task(
-                    tasks.AddConnection(self, attempt), attempt.delay
+                await self.schedule_task(
+                    AddConnection(self, attempt), attempt.delay
                 )
         else:
             await self._add_to_pool(conn)
@@ -322,13 +377,13 @@ class AsyncConnectionPool(BasePool[AsyncConnection]):
         await self._reset_connection(conn)
         if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
             # Connection no more in working state: create a new one.
-            self.run_task(tasks.AddConnection(self))
+            self.run_task(AddConnection(self))
             logger.warning("discarding closed connection: %s", conn)
             return
 
         # Check if the connection is past its best before date
         if conn._expire_at <= monotonic():
-            self.run_task(tasks.AddConnection(self))
+            self.run_task(AddConnection(self))
             logger.info("discarding expired connection")
             await conn.close()
             return
@@ -486,4 +541,106 @@ class AsyncClient:
             return True
 
 
-tasks.AsyncConnectionPool = AsyncConnectionPool  # type: ignore
+class MaintenanceTask(ABC):
+    """A task to run asynchronously to maintain the pool state."""
+
+    def __init__(self, pool: "AsyncConnectionPool"):
+        self.pool = ref(pool)
+
+    def __repr__(self) -> str:
+        pool = self.pool()
+        name = repr(pool.name) if pool else "<pool is gone>"
+        return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
+
+    async def run(self) -> None:
+        """Run the task.
+
+        This usually happens in a worker task. Call the concrete _run()
+        implementation, if the pool is still alive.
+        """
+        pool = self.pool()
+        if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+            return
+
+        await self._run(pool)
+
+    async def tick(self) -> None:
+        """Run the scheduled task
+
+        This function is called by the scheduler thread. Use a worker to
+        run the task for real in order to free the scheduler immediately.
+        """
+        pool = self.pool()
+        if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+            return
+
+        pool.run_task(self)
+
+    @abstractmethod
+    async def _run(self, pool: "AsyncConnectionPool") -> None:
+        ...
+
+
+class StopWorker(MaintenanceTask):
+    """Signal the maintenance thread to terminate."""
+
+    async def _run(self, pool: "AsyncConnectionPool") -> None:
+        pass
+
+
+class AddConnection(MaintenanceTask):
+    def __init__(
+        self,
+        pool: "AsyncConnectionPool",
+        attempt: Optional["ConnectionAttempt"] = None,
+    ):
+        super().__init__(pool)
+        self.attempt = attempt
+
+    async def _run(self, pool: "AsyncConnectionPool") -> None:
+        await pool._add_connection(self.attempt)
+
+
+class ReturnConnection(MaintenanceTask):
+    """Clean up and return a connection to the pool."""
+
+    def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection"):
+        super().__init__(pool)
+        self.conn = conn
+
+    async def _run(self, pool: "AsyncConnectionPool") -> None:
+        await pool._return_connection(self.conn)
+
+
+class ShrinkPool(MaintenanceTask):
+    """If the pool can shrink, remove one connection.
+
+    Re-schedule periodically and also reset the minimum number of connections
+    in the pool.
+    """
+
+    async def _run(self, pool: "AsyncConnectionPool") -> None:
+        # Reschedule the task now so that in case of any error we don't lose
+        # the periodic run.
+        await pool.schedule_task(self, pool.max_idle)
+        await pool._shrink_pool()
+
+
+class Schedule(MaintenanceTask):
+    """Schedule a task in the pool scheduler.
+
+    This task is a trampoline allowing a sync call (pool.run_task)
+    to execute an async one (pool.schedule_task).
+    """
+
+    def __init__(
+        self, pool: "AsyncConnectionPool", task: MaintenanceTask, delay: float
+    ):
+        super().__init__(pool)
+        self.task = task
+        self.delay = delay
+
+    async def _run(self, pool: "AsyncConnectionPool") -> None:
+        await pool.schedule_task(self.task, self.delay)
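The maintenance tasks above hold only a weak reference to the pool, so queued work silently becomes a no-op once the pool is closed or garbage collected, and Schedule acts as a trampoline from the sync run_task() entry point to the async scheduler. A minimal sketch of the weakref guard (Pool and Task here are illustrative stand-ins, not the pool's classes):

import asyncio
from weakref import ref


class Pool:
    closed = False


class Task:
    def __init__(self, pool: Pool):
        self.pool = ref(pool)  # weak reference: doesn't keep the pool alive

    async def run(self) -> None:
        pool = self.pool()
        if not pool or pool.closed:
            return  # pool gone or closed: quietly drop the work
        print("maintaining", pool)


async def main() -> None:
    p = Pool()
    t = Task(p)
    await t.run()  # runs the maintenance step
    del p          # drop the only strong reference to the pool
    await t.run()  # no-op: the weakref now returns None


asyncio.run(main())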
index c342e69d353411966bcf59945e9ed975c30f9e3e..c2d15ccc40362c44ac5e8c1447da9ad73db23c65 100644 (file)
@@ -5,21 +5,14 @@ psycopg3 connection pool base class and functionalities.
 # Copyright (C) 2021 The Psycopg Team
 
 import logging
-import threading
-from queue import Queue, Empty
 from random import random
-from typing import Any, Callable, Deque, Dict, Generic, List, Optional
+from typing import Any, Callable, Deque, Dict, Generic, Optional
 from collections import deque
 
 from ..proto import ConnectionType
 
-from . import tasks
-from .sched import Scheduler
-
 logger = logging.getLogger(__name__)
 
-WORKER_TIMEOUT = 60.0
-
 
 class BasePool(Generic[ConnectionType]):
 
@@ -69,7 +62,6 @@ class BasePool(Generic[ConnectionType]):
 
         self._nconns = minconn  # currently in the pool, out, being prepared
         self._pool: Deque[ConnectionType] = deque()
-        self._sched = Scheduler()
 
         # Min number of connections in the pool in a max_idle unit of time.
         # It is reset periodically by the ShrinkPool scheduled task.
@@ -78,61 +70,16 @@ class BasePool(Generic[ConnectionType]):
         # max_idle interval they weren't all used.
         self._nconns_min = minconn
 
-        self._tasks: "Queue[tasks.MaintenanceTask[ConnectionType]]" = Queue()
-        self._workers: List[threading.Thread] = []
-        for i in range(num_workers):
-            t = threading.Thread(
-                target=self.worker,
-                args=(self._tasks,),
-                name=f"{self.name}-worker-{i}",
-                daemon=True,
-            )
-            self._workers.append(t)
-
-        self._sched_runner = threading.Thread(
-            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
-        )
-
         # _close should be the last property to be set in the state
         # to avoid warning on __del__ in case __init__ fails.
         self._closed = False
 
-        # The object state is complete. Start the worker threads
-        self._sched_runner.start()
-        for t in self._workers:
-            t.start()
-
-        # populate the pool with initial minconn connections in background
-        for i in range(self._nconns):
-            self.run_task(tasks.AddConnection(self))
-
-        # Schedule a task to shrink the pool if connections over minconn have
-        # remained unused.
-        self.schedule_task(tasks.ShrinkPool(self), self.max_idle)
-
     def __repr__(self) -> str:
         return (
             f"<{self.__class__.__module__}.{self.__class__.__name__}"
             f" {self.name!r} at 0x{id(self):x}>"
         )
 
-    def __del__(self) -> None:
-        # If the '_closed' property is not set we probably failed in __init__.
-        # Don't try anything complicated as probably it won't work.
-        if getattr(self, "_closed", True):
-            return
-
-        # Things we can try to do on a best-effort basis while the world
-        # is crumbling (a-la Eternal Sunshine of the Spotless Mind)
-        # At worse we put an item in a queue that is being deleted.
-
-        # Stop the scheduler
-        self._sched.enter(0, None)
-
-        # Stop the worker threads
-        for i in range(len(self._workers)):
-            self.run_task(tasks.StopWorker(self))
-
     @property
     def minconn(self) -> int:
         return self._minconn
@@ -146,49 +93,6 @@ class BasePool(Generic[ConnectionType]):
         """`!True` if the pool is closed."""
         return self._closed
 
-    def run_task(self, task: tasks.MaintenanceTask[ConnectionType]) -> None:
-        """Run a maintenance task in a worker thread."""
-        self._tasks.put_nowait(task)
-
-    def schedule_task(
-        self, task: tasks.MaintenanceTask[Any], delay: float
-    ) -> None:
-        """Run a maintenance task in a worker thread in the future."""
-        self._sched.enter(delay, task.tick)
-
-    @classmethod
-    def worker(cls, q: "Queue[tasks.MaintenanceTask[ConnectionType]]") -> None:
-        """Runner to execute pending maintenance task.
-
-        The function is designed to run as a separate thread.
-
-        Block on the queue *q*, run a task received. Finish running if a
-        StopWorker is received.
-        """
-        # Don't make all the workers time out at the same moment
-        timeout = cls._jitter(WORKER_TIMEOUT, -0.1, 0.1)
-        while True:
-            # Use a timeout to make the wait interruptable
-            try:
-                task = q.get(timeout=timeout)
-            except Empty:
-                continue
-
-            if isinstance(task, tasks.StopWorker):
-                logger.debug(
-                    "terminating working thread %s",
-                    threading.current_thread().name,
-                )
-                return
-
-            # Run the task. Make sure don't die in the attempt.
-            try:
-                task.run()
-            except Exception as e:
-                logger.warning(
-                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
-                )
-
     @classmethod
     def _jitter(cls, value: float, min_pc: float, max_pc: float) -> float:
         """
index bbc89776a4f699bf6b6efd1db0415dd4777d1fc6..08a7435c486a5f0af7645733527236583843bd54 100644 (file)
@@ -6,17 +6,20 @@ psycopg3 synchronous connection pool
 
 import logging
 import threading
+from abc import ABC, abstractmethod
 from time import monotonic
+from queue import Queue, Empty
 from types import TracebackType
-from typing import Any, Callable, Deque, Iterator, Optional, Type
+from typing import Any, Callable, Deque, Iterator, List, Optional, Type
+from weakref import ref
 from contextlib import contextmanager
 from collections import deque
 
 from ..pq import TransactionStatus
 from ..connection import Connection
 
-from . import tasks
 from .base import ConnectionAttempt, BasePool
+from .sched import Scheduler
 from .errors import PoolClosed, PoolTimeout
 
 logger = logging.getLogger(__name__)
@@ -38,8 +41,54 @@ class ConnectionPool(BasePool[Connection]):
         # to notify that the pool is full
         self._pool_full_event: Optional[threading.Event] = None
 
+        self._sched = Scheduler()
+        self._tasks: "Queue[MaintenanceTask]" = Queue()
+        self._workers: List[threading.Thread] = []
+
         super().__init__(conninfo, **kwargs)
 
+        self._sched_runner = threading.Thread(
+            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
+        )
+        for i in range(self.num_workers):
+            t = threading.Thread(
+                target=self.worker,
+                args=(self._tasks,),
+                name=f"{self.name}-worker-{i}",
+                daemon=True,
+            )
+            self._workers.append(t)
+
+        # The object state is complete. Start the worker threads
+        self._sched_runner.start()
+        for t in self._workers:
+            t.start()
+
+        # populate the pool with initial minconn connections in background
+        for i in range(self._nconns):
+            self.run_task(AddConnection(self))
+
+        # Schedule a task to shrink the pool if connections over minconn have
+        # remained unused.
+        self.schedule_task(ShrinkPool(self), self.max_idle)
+
+    def __del__(self) -> None:
+        # If the '_closed' property is not set we probably failed in __init__.
+        # Don't try anything complicated as probably it won't work.
+        if getattr(self, "_closed", True):
+            return
+
+        # Things we can try to do on a best-effort basis while the world
+        # is crumbling (a-la Eternal Sunshine of the Spotless Mind)
+        # At worse we put an item in a queue that is being deleted.
+
+        # Stop the scheduler
+        self._sched.enter(0, None)
+
+        # Stop the worker threads
+        for i in range(len(self._workers)):
+            self.run_task(StopWorker(self))
+
     def wait_ready(self, timeout: float = 30.0) -> None:
         """
         Wait for the pool to be full after init.
@@ -117,7 +166,7 @@ class ConnectionPool(BasePool[Connection]):
                     logger.info(
                         "growing pool %r to %s", self.name, self._nconns
                     )
-                    self.run_task(tasks.AddConnection(self))
+                    self.run_task(AddConnection(self))
 
         # If we are in the waiting queue, wait to be assigned a connection
         # (outside the critical section, so only the waiting client is locked)
@@ -160,7 +209,7 @@ class ConnectionPool(BasePool[Connection]):
             return
 
         # Use a worker to perform eventual maintenance work in a separate thread
-        self.run_task(tasks.ReturnConnection(self, conn))
+        self.run_task(ReturnConnection(self, conn))
 
     def close(self, timeout: float = 1.0) -> None:
         """Close the pool and make it unavailable to new clients.
@@ -192,7 +241,7 @@ class ConnectionPool(BasePool[Connection]):
 
         # Stop the worker threads
         for i in range(len(self._workers)):
-            self.run_task(tasks.StopWorker(self))
+            self.run_task(StopWorker(self))
 
         # Signal to eventual clients in the queue that business is closed.
         for pos in waiting:
@@ -244,7 +293,7 @@ class ConnectionPool(BasePool[Connection]):
             self._nconns += ngrow
 
         for i in range(ngrow):
-            self.run_task(tasks.AddConnection(self))
+            self.run_task(AddConnection(self))
 
     def check(self) -> None:
         """Verify the state of the connections currently in the pool.
@@ -262,7 +311,7 @@ class ConnectionPool(BasePool[Connection]):
                 conn.execute("select 1")
             except Exception:
                 logger.warning("discarding broken connection: %s", conn)
-                self.run_task(tasks.AddConnection(self))
+                self.run_task(AddConnection(self))
             else:
                 self._add_to_pool(conn)
 
@@ -276,6 +325,49 @@ class ConnectionPool(BasePool[Connection]):
         """
         self._reconnect_failed(self)
 
+    def run_task(self, task: "MaintenanceTask") -> None:
+        """Run a maintenance task in a worker thread."""
+        self._tasks.put_nowait(task)
+
+    def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
+        """Run a maintenance task in a worker thread in the future."""
+        self._sched.enter(delay, task.tick)
+
+    _WORKER_TIMEOUT = 60.0
+
+    @classmethod
+    def worker(cls, q: "Queue[MaintenanceTask]") -> None:
+        """Runner to execute pending maintenance task.
+
+        The function is designed to run as a separate thread.
+
+        Block on the queue *q*, run a task received. Finish running if a
+        StopWorker is received.
+        """
+        # Don't make all the workers time out at the same moment
+        timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1)
+        while True:
+            # Use a timeout to make the wait interruptable
+            try:
+                task = q.get(timeout=timeout)
+            except Empty:
+                continue
+
+            if isinstance(task, StopWorker):
+                logger.debug(
+                    "terminating working thread %s",
+                    threading.current_thread().name,
+                )
+                return
+
+            # Run the task. Make sure we don't die in the attempt.
+            try:
+                task.run()
+            except Exception as e:
+                logger.warning(
+                    "task run %s failed: %s: %s", task, e.__class__.__name__, e
+                )
+
     def _connect(self) -> Connection:
         """Return a new connection configured for the pool."""
         conn = Connection.connect(self.conninfo, **self.kwargs)
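The synchronous pool keeps the thread-based workers shown above: q.get() uses a jittered timeout so the wait stays interruptible, and one StopWorker sentinel per thread shuts them down. A standalone sketch of that thread-side pattern (illustrative names, not the pool's API):

import threading
from queue import Queue, Empty


class StopWorker:
    """Sentinel telling a worker thread to exit."""


def worker(q: "Queue", timeout: float) -> None:
    while True:
        try:
            task = q.get(timeout=timeout)  # timeout keeps the wait interruptible
        except Empty:
            continue
        if isinstance(task, StopWorker):
            return
        task()  # run the queued callable


q: "Queue" = Queue()
threads = [
    threading.Thread(target=worker, args=(q, 60.0), daemon=True) for _ in range(3)
]
for t in threads:
    t.start()
q.put_nowait(lambda: print("maintenance tick"))
for _ in threads:
    q.put_nowait(StopWorker())  # one sentinel per worker thread
for t in threads:
    t.join()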
@@ -316,9 +408,7 @@ class ConnectionPool(BasePool[Connection]):
                 self.reconnect_failed()
             else:
                 attempt.update_delay(now)
-                self.schedule_task(
-                    tasks.AddConnection(self, attempt), attempt.delay
-                )
+                self.schedule_task(AddConnection(self, attempt), attempt.delay)
         else:
             self._add_to_pool(conn)
 
@@ -329,13 +419,13 @@ class ConnectionPool(BasePool[Connection]):
         self._reset_connection(conn)
         if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN:
             # Connection no more in working state: create a new one.
-            self.run_task(tasks.AddConnection(self))
+            self.run_task(AddConnection(self))
             logger.warning("discarding closed connection: %s", conn)
             return
 
         # Check if the connection is past its best before date
         if conn._expire_at <= monotonic():
-            self.run_task(tasks.AddConnection(self))
+            self.run_task(AddConnection(self))
             logger.info("discarding expired connection")
             conn.close()
             return
@@ -491,4 +581,94 @@ class WaitingClient:
             return True
 
 
-tasks.ConnectionPool = ConnectionPool  # type: ignore
+class MaintenanceTask(ABC):
+    """A task to run asynchronously to maintain the pool state."""
+
+    def __init__(self, pool: "ConnectionPool"):
+        self.pool = ref(pool)
+        logger.debug(
+            "task created in %s: %s", threading.current_thread().name, self
+        )
+
+    def __repr__(self) -> str:
+        pool = self.pool()
+        name = repr(pool.name) if pool else "<pool is gone>"
+        return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
+
+    def run(self) -> None:
+        """Run the task.
+
+        This usually happens in a worker thread. Call the concrete _run()
+        implementation, if the pool is still alive.
+        """
+        pool = self.pool()
+        if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+            return
+
+        logger.debug(
+            "task running in %s: %s", threading.current_thread().name, self
+        )
+        self._run(pool)
+
+    def tick(self) -> None:
+        """Run the scheduled task
+
+        This function is called by the scheduler thread. Use a worker to
+        run the task for real in order to free the scheduler immediately.
+        """
+        pool = self.pool()
+        if not pool or pool.closed:
+            # The pool is no longer working. Quietly discard the operation.
+            return
+
+        pool.run_task(self)
+
+    @abstractmethod
+    def _run(self, pool: "ConnectionPool") -> None:
+        ...
+
+
+class StopWorker(MaintenanceTask):
+    """Signal the maintenance thread to terminate."""
+
+    def _run(self, pool: "ConnectionPool") -> None:
+        pass
+
+
+class AddConnection(MaintenanceTask):
+    def __init__(
+        self,
+        pool: "ConnectionPool",
+        attempt: Optional["ConnectionAttempt"] = None,
+    ):
+        super().__init__(pool)
+        self.attempt = attempt
+
+    def _run(self, pool: "ConnectionPool") -> None:
+        pool._add_connection(self.attempt)
+
+
+class ReturnConnection(MaintenanceTask):
+    """Clean up and return a connection to the pool."""
+
+    def __init__(self, pool: "ConnectionPool", conn: "Connection"):
+        super().__init__(pool)
+        self.conn = conn
+
+    def _run(self, pool: "ConnectionPool") -> None:
+        pool._return_connection(self.conn)
+
+
+class ShrinkPool(MaintenanceTask):
+    """If the pool can shrink, remove one connection.
+
+    Re-schedule periodically and also reset the minimum number of connections
+    in the pool.
+    """
+
+    def _run(self, pool: "ConnectionPool") -> None:
+        # Reschedule the task now so that in case of any error we don't lose
+        # the periodic run.
+        pool.schedule_task(self, pool.max_idle)
+        pool._shrink_pool()
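ShrinkPool reschedules itself before doing the actual shrink, so an error in _shrink_pool() cannot break the periodic run. The same "reschedule first, then run" idea in isolation, with threading.Timer standing in for the pool's Scheduler (an illustrative substitution, not what the pool uses):

import threading


def periodic(interval: float, body) -> None:
    # Re-arm the next run *before* doing the work, so an exception in
    # body() cannot break the periodic schedule.
    t = threading.Timer(interval, periodic, args=(interval, body))
    t.daemon = True
    t.start()
    body()


# usage: periodic(600.0, lambda: print("shrink check"))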
diff --git a/psycopg3/psycopg3/pool/tasks.py b/psycopg3/psycopg3/pool/tasks.py
deleted file mode 100644 (file)
index 6028c3a..0000000
+++ /dev/null
@@ -1,157 +0,0 @@
-"""
-Maintenance tasks for the connection pools.
-"""
-
-# Copyright (C) 2021 The Psycopg Team
-
-import asyncio
-import logging
-import threading
-from abc import ABC, abstractmethod
-from typing import Any, cast, Generic, Optional, Type, TYPE_CHECKING
-from weakref import ref
-
-from ..proto import ConnectionType
-from .. import Connection, AsyncConnection
-
-if TYPE_CHECKING:
-    from .pool import ConnectionPool
-    from .async_pool import AsyncConnectionPool
-    from .base import BasePool, ConnectionAttempt
-else:
-    # Injected at pool.py and async_pool.py import
-    ConnectionPool: "Type[BasePool[Connection]]"
-    AsyncConnectionPool: "Type[BasePool[AsyncConnection]]"
-
-logger = logging.getLogger(__name__)
-
-
-class MaintenanceTask(ABC, Generic[ConnectionType]):
-    """A task to run asynchronously to maintain the pool state."""
-
-    TIMEOUT = 10.0
-
-    def __init__(self, pool: "BasePool[Any]"):
-        if isinstance(pool, AsyncConnectionPool):
-            self.event = threading.Event()
-
-        self.pool = ref(pool)
-        logger.debug(
-            "task created in %s: %s", threading.current_thread().name, self
-        )
-
-    def __repr__(self) -> str:
-        pool = self.pool()
-        name = repr(pool.name) if pool else "<pool is gone>"
-        return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>"
-
-    def run(self) -> None:
-        """Run the task.
-
-        This usually happens in a worker thread. Call the concrete _run()
-        implementation, if the pool is still alive.
-        """
-        pool = self.pool()
-        if not pool or pool.closed:
-            # Pool is no more working. Quietly discard the operation.
-            return
-
-        logger.debug(
-            "task running in %s: %s", threading.current_thread().name, self
-        )
-        if isinstance(pool, ConnectionPool):
-            self._run(pool)
-        elif isinstance(pool, AsyncConnectionPool):
-            self.event.clear()
-            asyncio.run_coroutine_threadsafe(self._run_async(pool), pool.loop)
-            if not self.event.wait(self.TIMEOUT):
-                logger.warning(
-                    "event %s didn't terminate after %s sec", self.TIMEOUT
-                )
-        else:
-            logger.error("%s run got %s instead of a pool", self, pool)
-
-    def tick(self) -> None:
-        """Run the scheduled task
-
-        This function is called by the scheduler thread. Use a worker to
-        run the task for real in order to free the scheduler immediately.
-        """
-        pool = self.pool()
-        if not pool or pool.closed:
-            # Pool is no more working. Quietly discard the operation.
-            return
-
-        pool.run_task(self)
-
-    @abstractmethod
-    def _run(self, pool: "ConnectionPool") -> None:
-        ...
-
-    @abstractmethod
-    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
-        self.event.set()
-
-
-class StopWorker(MaintenanceTask[ConnectionType]):
-    """Signal the maintenance thread to terminate."""
-
-    def _run(self, pool: "ConnectionPool") -> None:
-        pass
-
-    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
-        await super()._run_async(pool)
-
-
-class AddConnection(MaintenanceTask[ConnectionType]):
-    def __init__(
-        self,
-        pool: "BasePool[Any]",
-        attempt: Optional["ConnectionAttempt"] = None,
-    ):
-        super().__init__(pool)
-        self.attempt = attempt
-
-    def _run(self, pool: "ConnectionPool") -> None:
-        pool._add_connection(self.attempt)
-
-    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
-        logger.debug("run async 1")
-        await pool._add_connection(self.attempt)
-        logger.debug("run async 2")
-        await super()._run_async(pool)
-        logger.debug("run async 3")
-
-
-class ReturnConnection(MaintenanceTask[ConnectionType]):
-    """Clean up and return a connection to the pool."""
-
-    def __init__(self, pool: "BasePool[Any]", conn: "ConnectionType"):
-        super().__init__(pool)
-        self.conn = conn
-
-    def _run(self, pool: "ConnectionPool") -> None:
-        pool._return_connection(cast(Connection, self.conn))
-
-    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
-        await pool._return_connection(cast(AsyncConnection, self.conn))
-        await super()._run_async(pool)
-
-
-class ShrinkPool(MaintenanceTask[ConnectionType]):
-    """If the pool can shrink, remove one connection.
-
-    Re-schedule periodically and also reset the minimum number of connections
-    in the pool.
-    """
-
-    def _run(self, pool: "ConnectionPool") -> None:
-        # Reschedule the task now so that in case of any error we don't lose
-        # the periodic run.
-        pool.schedule_task(self, pool.max_idle)
-        pool._shrink_pool()
-
-    async def _run_async(self, pool: "AsyncConnectionPool") -> None:
-        pool.schedule_task(self, pool.max_idle)
-        await pool._shrink_pool()
-        await super()._run_async(pool)
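For contrast, the deleted tasks.py bridged from worker threads into the event loop with run_coroutine_threadsafe, guarded by a threading.Event and a timeout. A minimal sketch of that thread-to-loop bridge (here a concurrent.futures timeout stands in for the Event-based wait):

import asyncio
import threading

# A loop running on its own thread, as the old async pool effectively had.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()


async def maintenance() -> None:
    print("running on the loop thread")


# From a worker thread: schedule the coroutine on the loop and block until done.
future = asyncio.run_coroutine_threadsafe(maintenance(), loop)
future.result(timeout=10.0)  # bounded wait, like the old TIMEOUT guard
loop.call_soon_threadsafe(loop.stop)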
index 6c400ac2220cc522dfa7a36eea4fac679b10bef7..77993c74d2a9ccc159f21fbecff2f06512fac2fb 100644 (file)
@@ -480,7 +480,7 @@ def test_grow(dsn, monkeypatch, retries):
 @pytest.mark.slow
 def test_shrink(dsn, monkeypatch):
 
-    from psycopg3.pool.tasks import ShrinkPool
+    from psycopg3.pool.pool import ShrinkPool
 
     results = []
 
index 3fad68414f22b67d5788862cd3b572fe4460abd7..246bccffa02b4b3b58baacf30b27ec771039e120 100644 (file)
@@ -369,16 +369,16 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
     assert "BAD" in recs[2].message
 
 
-async def test_close_no_threads(dsn):
+async def test_close_no_tasks(dsn):
     p = pool.AsyncConnectionPool(dsn)
-    assert p._sched_runner.is_alive()
+    assert not p._sched_runner.done()
     for t in p._workers:
-        assert t.is_alive()
+        assert not t.done()
 
     await p.close()
-    assert not p._sched_runner.is_alive()
+    assert p._sched_runner.done()
     for t in p._workers:
-        assert not t.is_alive()
+        assert t.done()
 
 
 async def test_putconn_no_pool(dsn):
@@ -409,16 +409,6 @@ async def test_del_no_warning(dsn, recwarn):
     assert not recwarn
 
 
-@pytest.mark.slow
-async def test_del_stop_threads(dsn):
-    p = pool.AsyncConnectionPool(dsn)
-    ts = [p._sched_runner] + p._workers
-    del p
-    await asyncio.sleep(0.1)
-    for t in ts:
-        assert not t.is_alive()
-
-
 async def test_closed_getconn(dsn):
     p = pool.AsyncConnectionPool(dsn, minconn=1)
     assert not p.closed
@@ -502,18 +492,18 @@ async def test_grow(dsn, monkeypatch, retries):
 @pytest.mark.slow
 async def test_shrink(dsn, monkeypatch):
 
-    from psycopg3.pool.tasks import ShrinkPool
+    from psycopg3.pool.async_pool import ShrinkPool
 
     results = []
 
-    async def run_async_hacked(self, pool):
+    async def run_hacked(self, pool):
         n0 = pool._nconns
         await orig_run(self, pool)
         n1 = pool._nconns
         results.append((n0, n1))
 
-    orig_run = ShrinkPool._run_async
-    monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked)
+    orig_run = ShrinkPool._run
+    monkeypatch.setattr(ShrinkPool, "_run", run_hacked)
 
     async def worker(n):
         async with p.connection() as conn: