From: Daniele Varrazzo Date: Sun, 7 Mar 2021 23:52:58 +0000 (+0100) Subject: Don't use threads to manage the scheduler and workers of async tasks X-Git-Tag: 3.0.dev0~87^2~26 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=da8ac2b2e303b5341f0e45cab5eb3ba3cf55b48a;p=thirdparty%2Fpsycopg.git Don't use threads to manage the scheduler and workers of async tasks --- diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index 144b58e1f..a8400aa1b 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -6,18 +6,20 @@ psycopg3 synchronous connection pool import asyncio import logging +from abc import ABC, abstractmethod from time import monotonic from types import TracebackType from typing import Any, AsyncIterator, Awaitable, Callable, Deque -from typing import Optional, Type +from typing import List, Optional, Type +from weakref import ref from collections import deque from ..pq import TransactionStatus from ..connection import AsyncConnection -from ..utils.compat import asynccontextmanager, get_running_loop +from ..utils.compat import asynccontextmanager, create_task -from . import tasks from .base import ConnectionAttempt, BasePool +from .sched import AsyncScheduler from .errors import PoolClosed, PoolTimeout logger = logging.getLogger(__name__) @@ -40,10 +42,64 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): # to notify that the pool is full self._pool_full_event: Optional[asyncio.Event] = None - self.loop = get_running_loop() + self._sched = AsyncScheduler() + self._tasks: "asyncio.Queue[MaintenanceTask]" = asyncio.Queue() + self._workers: "List[asyncio.Task[None]]" = [] super().__init__(conninfo, **kwargs) + self._sched_runner = create_task( + self._sched.run(), name=f"{self.name}-scheduler" + ) + for i in range(self.num_workers): + t = create_task( + self.worker(self._tasks), + name=f"{self.name}-worker-{i}", + ) + self._workers.append(t) + + # populate the pool with initial minconn connections in background + for i in range(self._nconns): + self.run_task(AddConnection(self)) + + # Schedule a task to shrink the pool if connections over minconn have + # remained unused. + self.run_task(Schedule(self, ShrinkPool(self), self.max_idle)) + + def run_task(self, task: "MaintenanceTask") -> None: + """Run a maintenance task in a worker thread.""" + self._tasks.put_nowait(task) + + async def schedule_task( + self, task: "MaintenanceTask", delay: float + ) -> None: + """Run a maintenance task in a worker thread in the future.""" + await self._sched.enter(delay, task.tick) + + @classmethod + async def worker(cls, q: "asyncio.Queue[MaintenanceTask]") -> None: + """Runner to execute pending maintenance task. + + The function is designed to run as a separate thread. + + Block on the queue *q*, run a task received. Finish running if a + StopWorker is received. + """ + while True: + task = await q.get() + + if isinstance(task, StopWorker): + logger.debug("terminating working task") + return + + # Run the task. Make sure don't die in the attempt. + try: + await task.run() + except Exception as e: + logger.warning( + "task run %s failed: %s: %s", task, e.__class__.__name__, e + ) + async def wait_ready(self, timeout: float = 30.0) -> None: """ Wait for the pool to be full after init. @@ -125,7 +181,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): logger.info( "growing pool %r to %s", self.name, self._nconns ) - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) # If we are in the waiting queue, wait to be assigned a connection # (outside the critical section, so only the waiting client is locked) @@ -168,9 +224,9 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): return # Use a worker to perform eventual maintenance work in a separate thread - self.run_task(tasks.ReturnConnection(self, conn)) + self.run_task(ReturnConnection(self, conn)) - async def close(self, timeout: float = 1.0) -> None: + async def close(self, timeout: float = 5.0) -> None: """Close the pool and make it unavailable to new clients. All the waiting and future client will fail to acquire a connection @@ -196,11 +252,11 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): # putconn will just close the returned connection. # Stop the scheduler - self._sched.enter(0, None) + await self._sched.enter(0, None) # Stop the worker threads - for i in range(len(self._workers)): - self.run_task(tasks.StopWorker(self)) + for w in self._workers: + self.run_task(StopWorker(self)) # Signal to eventual clients in the queue that business is closed. for pos in waiting: @@ -211,18 +267,17 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): await conn.close() # Wait for the worker threads to terminate + wait = asyncio.gather(self._sched_runner, *self._workers) if timeout > 0: - for t in [self._sched_runner] + self._workers: - if not t.is_alive(): - continue - await self.loop.run_in_executor(None, lambda: t.join(timeout)) - if t.is_alive(): - logger.warning( - "couldn't stop thread %s in pool %r within %s seconds", - t, - self.name, - timeout, - ) + wait = asyncio.wait_for(asyncio.shield(wait), timeout=timeout) + try: + await wait + except asyncio.TimeoutError: + logger.warning( + "couldn't stop pool %r tasks within %s seconds", + self.name, + timeout, + ) async def __aenter__(self) -> "AsyncConnectionPool": return self @@ -254,7 +309,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): self._nconns += ngrow for i in range(ngrow): - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) async def configure(self, conn: AsyncConnection) -> None: """Configure a connection after creation.""" @@ -309,8 +364,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): self.reconnect_failed() else: attempt.update_delay(now) - self.schedule_task( - tasks.AddConnection(self, attempt), attempt.delay + await self.schedule_task( + AddConnection(self, attempt), attempt.delay ) else: await self._add_to_pool(conn) @@ -322,13 +377,13 @@ class AsyncConnectionPool(BasePool[AsyncConnection]): await self._reset_connection(conn) if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: # Connection no more in working state: create a new one. - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) logger.warning("discarding closed connection: %s", conn) return # Check if the connection is past its best before date if conn._expire_at <= monotonic(): - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) logger.info("discarding expired connection") await conn.close() return @@ -486,4 +541,106 @@ class AsyncClient: return True -tasks.AsyncConnectionPool = AsyncConnectionPool # type: ignore +class MaintenanceTask(ABC): + """A task to run asynchronously to maintain the pool state.""" + + def __init__(self, pool: "AsyncConnectionPool"): + self.pool = ref(pool) + + def __repr__(self) -> str: + pool = self.pool() + name = repr(pool.name) if pool else "" + return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" + + async def run(self) -> None: + """Run the task. + + This usually happens in a worker thread. Call the concrete _run() + implementation, if the pool is still alive. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + return + + await self._run(pool) + + async def tick(self) -> None: + """Run the scheduled task + + This function is called by the scheduler thread. Use a worker to + run the task for real in order to free the scheduler immediately. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + return + + pool.run_task(self) + + @abstractmethod + async def _run(self, pool: "AsyncConnectionPool") -> None: + ... + + +class StopWorker(MaintenanceTask): + """Signal the maintenance thread to terminate.""" + + async def _run(self, pool: "AsyncConnectionPool") -> None: + pass + + +class AddConnection(MaintenanceTask): + def __init__( + self, + pool: "AsyncConnectionPool", + attempt: Optional["ConnectionAttempt"] = None, + ): + super().__init__(pool) + self.attempt = attempt + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool._add_connection(self.attempt) + + +class ReturnConnection(MaintenanceTask): + """Clean up and return a connection to the pool.""" + + def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection"): + super().__init__(pool) + self.conn = conn + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool._return_connection(self.conn) + + +class ShrinkPool(MaintenanceTask): + """If the pool can shrink, remove one connection. + + Re-schedule periodically and also reset the minimum number of connections + in the pool. + """ + + async def _run(self, pool: "AsyncConnectionPool") -> None: + # Reschedule the task now so that in case of any error we don't lose + # the periodic run. + await pool.schedule_task(self, pool.max_idle) + await pool._shrink_pool() + + +class Schedule(MaintenanceTask): + """Schedule a task in the pool scheduler. + + This task is a trampoline to allow to use a sync call (pool.run_task) + to execute an async one (pool.schedule_task). + """ + + def __init__( + self, pool: "AsyncConnectionPool", task: MaintenanceTask, delay: float + ): + super().__init__(pool) + self.task = task + self.delay = delay + + async def _run(self, pool: "AsyncConnectionPool") -> None: + await pool.schedule_task(self.task, self.delay) diff --git a/psycopg3/psycopg3/pool/base.py b/psycopg3/psycopg3/pool/base.py index c342e69d3..c2d15ccc4 100644 --- a/psycopg3/psycopg3/pool/base.py +++ b/psycopg3/psycopg3/pool/base.py @@ -5,21 +5,14 @@ psycopg3 connection pool base class and functionalities. # Copyright (C) 2021 The Psycopg Team import logging -import threading -from queue import Queue, Empty from random import random -from typing import Any, Callable, Deque, Dict, Generic, List, Optional +from typing import Any, Callable, Deque, Dict, Generic, Optional from collections import deque from ..proto import ConnectionType -from . import tasks -from .sched import Scheduler - logger = logging.getLogger(__name__) -WORKER_TIMEOUT = 60.0 - class BasePool(Generic[ConnectionType]): @@ -69,7 +62,6 @@ class BasePool(Generic[ConnectionType]): self._nconns = minconn # currently in the pool, out, being prepared self._pool: Deque[ConnectionType] = deque() - self._sched = Scheduler() # Min number of connections in the pool in a max_idle unit of time. # It is reset periodically by the ShrinkPool scheduled task. @@ -78,61 +70,16 @@ class BasePool(Generic[ConnectionType]): # max_idle interval they weren't all used. self._nconns_min = minconn - self._tasks: "Queue[tasks.MaintenanceTask[ConnectionType]]" = Queue() - self._workers: List[threading.Thread] = [] - for i in range(num_workers): - t = threading.Thread( - target=self.worker, - args=(self._tasks,), - name=f"{self.name}-worker-{i}", - daemon=True, - ) - self._workers.append(t) - - self._sched_runner = threading.Thread( - target=self._sched.run, name=f"{self.name}-scheduler", daemon=True - ) - # _close should be the last property to be set in the state # to avoid warning on __del__ in case __init__ fails. self._closed = False - # The object state is complete. Start the worker threads - self._sched_runner.start() - for t in self._workers: - t.start() - - # populate the pool with initial minconn connections in background - for i in range(self._nconns): - self.run_task(tasks.AddConnection(self)) - - # Schedule a task to shrink the pool if connections over minconn have - # remained unused. - self.schedule_task(tasks.ShrinkPool(self), self.max_idle) - def __repr__(self) -> str: return ( f"<{self.__class__.__module__}.{self.__class__.__name__}" f" {self.name!r} at 0x{id(self):x}>" ) - def __del__(self) -> None: - # If the '_closed' property is not set we probably failed in __init__. - # Don't try anything complicated as probably it won't work. - if getattr(self, "_closed", True): - return - - # Things we can try to do on a best-effort basis while the world - # is crumbling (a-la Eternal Sunshine of the Spotless Mind) - # At worse we put an item in a queue that is being deleted. - - # Stop the scheduler - self._sched.enter(0, None) - - # Stop the worker threads - for i in range(len(self._workers)): - self.run_task(tasks.StopWorker(self)) - @property def minconn(self) -> int: return self._minconn @@ -146,49 +93,6 @@ class BasePool(Generic[ConnectionType]): """`!True` if the pool is closed.""" return self._closed - def run_task(self, task: tasks.MaintenanceTask[ConnectionType]) -> None: - """Run a maintenance task in a worker thread.""" - self._tasks.put_nowait(task) - - def schedule_task( - self, task: tasks.MaintenanceTask[Any], delay: float - ) -> None: - """Run a maintenance task in a worker thread in the future.""" - self._sched.enter(delay, task.tick) - - @classmethod - def worker(cls, q: "Queue[tasks.MaintenanceTask[ConnectionType]]") -> None: - """Runner to execute pending maintenance task. - - The function is designed to run as a separate thread. - - Block on the queue *q*, run a task received. Finish running if a - StopWorker is received. - """ - # Don't make all the workers time out at the same moment - timeout = cls._jitter(WORKER_TIMEOUT, -0.1, 0.1) - while True: - # Use a timeout to make the wait interruptable - try: - task = q.get(timeout=timeout) - except Empty: - continue - - if isinstance(task, tasks.StopWorker): - logger.debug( - "terminating working thread %s", - threading.current_thread().name, - ) - return - - # Run the task. Make sure don't die in the attempt. - try: - task.run() - except Exception as e: - logger.warning( - "task run %s failed: %s: %s", task, e.__class__.__name__, e - ) - @classmethod def _jitter(cls, value: float, min_pc: float, max_pc: float) -> float: """ diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index bbc89776a..08a7435c4 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -6,17 +6,20 @@ psycopg3 synchronous connection pool import logging import threading +from abc import ABC, abstractmethod from time import monotonic +from queue import Queue, Empty from types import TracebackType -from typing import Any, Callable, Deque, Iterator, Optional, Type +from typing import Any, Callable, Deque, Iterator, List, Optional, Type +from weakref import ref from contextlib import contextmanager from collections import deque from ..pq import TransactionStatus from ..connection import Connection -from . import tasks from .base import ConnectionAttempt, BasePool +from .sched import Scheduler from .errors import PoolClosed, PoolTimeout logger = logging.getLogger(__name__) @@ -38,8 +41,54 @@ class ConnectionPool(BasePool[Connection]): # to notify that the pool is full self._pool_full_event: Optional[threading.Event] = None + self._sched = Scheduler() + self._tasks: "Queue[MaintenanceTask]" = Queue() + self._workers: List[threading.Thread] = [] + super().__init__(conninfo, **kwargs) + self._sched_runner = threading.Thread( + target=self._sched.run, name=f"{self.name}-scheduler", daemon=True + ) + for i in range(self.num_workers): + t = threading.Thread( + target=self.worker, + args=(self._tasks,), + name=f"{self.name}-worker-{i}", + daemon=True, + ) + self._workers.append(t) + + # The object state is complete. Start the worker threads + self._sched_runner.start() + for t in self._workers: + t.start() + + # populate the pool with initial minconn connections in background + for i in range(self._nconns): + self.run_task(AddConnection(self)) + + # Schedule a task to shrink the pool if connections over minconn have + # remained unused. + self.schedule_task(ShrinkPool(self), self.max_idle) + + def __del__(self) -> None: + # If the '_closed' property is not set we probably failed in __init__. + # Don't try anything complicated as probably it won't work. + if getattr(self, "_closed", True): + return + + # Things we can try to do on a best-effort basis while the world + # is crumbling (a-la Eternal Sunshine of the Spotless Mind) + # At worse we put an item in a queue that is being deleted. + + # Stop the scheduler + self._sched.enter(0, None) + + # Stop the worker threads + for i in range(len(self._workers)): + self.run_task(StopWorker(self)) + def wait_ready(self, timeout: float = 30.0) -> None: """ Wait for the pool to be full after init. @@ -117,7 +166,7 @@ class ConnectionPool(BasePool[Connection]): logger.info( "growing pool %r to %s", self.name, self._nconns ) - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) # If we are in the waiting queue, wait to be assigned a connection # (outside the critical section, so only the waiting client is locked) @@ -160,7 +209,7 @@ class ConnectionPool(BasePool[Connection]): return # Use a worker to perform eventual maintenance work in a separate thread - self.run_task(tasks.ReturnConnection(self, conn)) + self.run_task(ReturnConnection(self, conn)) def close(self, timeout: float = 1.0) -> None: """Close the pool and make it unavailable to new clients. @@ -192,7 +241,7 @@ class ConnectionPool(BasePool[Connection]): # Stop the worker threads for i in range(len(self._workers)): - self.run_task(tasks.StopWorker(self)) + self.run_task(StopWorker(self)) # Signal to eventual clients in the queue that business is closed. for pos in waiting: @@ -244,7 +293,7 @@ class ConnectionPool(BasePool[Connection]): self._nconns += ngrow for i in range(ngrow): - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) def check(self) -> None: """Verify the state of the connections currently in the pool. @@ -262,7 +311,7 @@ class ConnectionPool(BasePool[Connection]): conn.execute("select 1") except Exception: logger.warning("discarding broken connection: %s", conn) - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) else: self._add_to_pool(conn) @@ -276,6 +325,49 @@ class ConnectionPool(BasePool[Connection]): """ self._reconnect_failed(self) + def run_task(self, task: "MaintenanceTask") -> None: + """Run a maintenance task in a worker thread.""" + self._tasks.put_nowait(task) + + def schedule_task(self, task: "MaintenanceTask", delay: float) -> None: + """Run a maintenance task in a worker thread in the future.""" + self._sched.enter(delay, task.tick) + + _WORKER_TIMEOUT = 60.0 + + @classmethod + def worker(cls, q: "Queue[MaintenanceTask]") -> None: + """Runner to execute pending maintenance task. + + The function is designed to run as a separate thread. + + Block on the queue *q*, run a task received. Finish running if a + StopWorker is received. + """ + # Don't make all the workers time out at the same moment + timeout = cls._jitter(cls._WORKER_TIMEOUT, -0.1, 0.1) + while True: + # Use a timeout to make the wait interruptable + try: + task = q.get(timeout=timeout) + except Empty: + continue + + if isinstance(task, StopWorker): + logger.debug( + "terminating working thread %s", + threading.current_thread().name, + ) + return + + # Run the task. Make sure don't die in the attempt. + try: + task.run() + except Exception as e: + logger.warning( + "task run %s failed: %s: %s", task, e.__class__.__name__, e + ) + def _connect(self) -> Connection: """Return a new connection configured for the pool.""" conn = Connection.connect(self.conninfo, **self.kwargs) @@ -316,9 +408,7 @@ class ConnectionPool(BasePool[Connection]): self.reconnect_failed() else: attempt.update_delay(now) - self.schedule_task( - tasks.AddConnection(self, attempt), attempt.delay - ) + self.schedule_task(AddConnection(self, attempt), attempt.delay) else: self._add_to_pool(conn) @@ -329,13 +419,13 @@ class ConnectionPool(BasePool[Connection]): self._reset_connection(conn) if conn.pgconn.transaction_status == TransactionStatus.UNKNOWN: # Connection no more in working state: create a new one. - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) logger.warning("discarding closed connection: %s", conn) return # Check if the connection is past its best before date if conn._expire_at <= monotonic(): - self.run_task(tasks.AddConnection(self)) + self.run_task(AddConnection(self)) logger.info("discarding expired connection") conn.close() return @@ -491,4 +581,94 @@ class WaitingClient: return True -tasks.ConnectionPool = ConnectionPool # type: ignore +class MaintenanceTask(ABC): + """A task to run asynchronously to maintain the pool state.""" + + def __init__(self, pool: "ConnectionPool"): + self.pool = ref(pool) + logger.debug( + "task created in %s: %s", threading.current_thread().name, self + ) + + def __repr__(self) -> str: + pool = self.pool() + name = repr(pool.name) if pool else "" + return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" + + def run(self) -> None: + """Run the task. + + This usually happens in a worker thread. Call the concrete _run() + implementation, if the pool is still alive. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + return + + logger.debug( + "task running in %s: %s", threading.current_thread().name, self + ) + self._run(pool) + + def tick(self) -> None: + """Run the scheduled task + + This function is called by the scheduler thread. Use a worker to + run the task for real in order to free the scheduler immediately. + """ + pool = self.pool() + if not pool or pool.closed: + # Pool is no more working. Quietly discard the operation. + return + + pool.run_task(self) + + @abstractmethod + def _run(self, pool: "ConnectionPool") -> None: + ... + + +class StopWorker(MaintenanceTask): + """Signal the maintenance thread to terminate.""" + + def _run(self, pool: "ConnectionPool") -> None: + pass + + +class AddConnection(MaintenanceTask): + def __init__( + self, + pool: "ConnectionPool", + attempt: Optional["ConnectionAttempt"] = None, + ): + super().__init__(pool) + self.attempt = attempt + + def _run(self, pool: "ConnectionPool") -> None: + pool._add_connection(self.attempt) + + +class ReturnConnection(MaintenanceTask): + """Clean up and return a connection to the pool.""" + + def __init__(self, pool: "ConnectionPool", conn: "Connection"): + super().__init__(pool) + self.conn = conn + + def _run(self, pool: "ConnectionPool") -> None: + pool._return_connection(self.conn) + + +class ShrinkPool(MaintenanceTask): + """If the pool can shrink, remove one connection. + + Re-schedule periodically and also reset the minimum number of connections + in the pool. + """ + + def _run(self, pool: "ConnectionPool") -> None: + # Reschedule the task now so that in case of any error we don't lose + # the periodic run. + pool.schedule_task(self, pool.max_idle) + pool._shrink_pool() diff --git a/psycopg3/psycopg3/pool/tasks.py b/psycopg3/psycopg3/pool/tasks.py deleted file mode 100644 index 6028c3a88..000000000 --- a/psycopg3/psycopg3/pool/tasks.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -Maintenance tasks for the connection pools. -""" - -# Copyright (C) 2021 The Psycopg Team - -import asyncio -import logging -import threading -from abc import ABC, abstractmethod -from typing import Any, cast, Generic, Optional, Type, TYPE_CHECKING -from weakref import ref - -from ..proto import ConnectionType -from .. import Connection, AsyncConnection - -if TYPE_CHECKING: - from .pool import ConnectionPool - from .async_pool import AsyncConnectionPool - from .base import BasePool, ConnectionAttempt -else: - # Injected at pool.py and async_pool.py import - ConnectionPool: "Type[BasePool[Connection]]" - AsyncConnectionPool: "Type[BasePool[AsyncConnection]]" - -logger = logging.getLogger(__name__) - - -class MaintenanceTask(ABC, Generic[ConnectionType]): - """A task to run asynchronously to maintain the pool state.""" - - TIMEOUT = 10.0 - - def __init__(self, pool: "BasePool[Any]"): - if isinstance(pool, AsyncConnectionPool): - self.event = threading.Event() - - self.pool = ref(pool) - logger.debug( - "task created in %s: %s", threading.current_thread().name, self - ) - - def __repr__(self) -> str: - pool = self.pool() - name = repr(pool.name) if pool else "" - return f"<{self.__class__.__name__} {name} at 0x{id(self):x}>" - - def run(self) -> None: - """Run the task. - - This usually happens in a worker thread. Call the concrete _run() - implementation, if the pool is still alive. - """ - pool = self.pool() - if not pool or pool.closed: - # Pool is no more working. Quietly discard the operation. - return - - logger.debug( - "task running in %s: %s", threading.current_thread().name, self - ) - if isinstance(pool, ConnectionPool): - self._run(pool) - elif isinstance(pool, AsyncConnectionPool): - self.event.clear() - asyncio.run_coroutine_threadsafe(self._run_async(pool), pool.loop) - if not self.event.wait(self.TIMEOUT): - logger.warning( - "event %s didn't terminate after %s sec", self.TIMEOUT - ) - else: - logger.error("%s run got %s instead of a pool", self, pool) - - def tick(self) -> None: - """Run the scheduled task - - This function is called by the scheduler thread. Use a worker to - run the task for real in order to free the scheduler immediately. - """ - pool = self.pool() - if not pool or pool.closed: - # Pool is no more working. Quietly discard the operation. - return - - pool.run_task(self) - - @abstractmethod - def _run(self, pool: "ConnectionPool") -> None: - ... - - @abstractmethod - async def _run_async(self, pool: "AsyncConnectionPool") -> None: - self.event.set() - - -class StopWorker(MaintenanceTask[ConnectionType]): - """Signal the maintenance thread to terminate.""" - - def _run(self, pool: "ConnectionPool") -> None: - pass - - async def _run_async(self, pool: "AsyncConnectionPool") -> None: - await super()._run_async(pool) - - -class AddConnection(MaintenanceTask[ConnectionType]): - def __init__( - self, - pool: "BasePool[Any]", - attempt: Optional["ConnectionAttempt"] = None, - ): - super().__init__(pool) - self.attempt = attempt - - def _run(self, pool: "ConnectionPool") -> None: - pool._add_connection(self.attempt) - - async def _run_async(self, pool: "AsyncConnectionPool") -> None: - logger.debug("run async 1") - await pool._add_connection(self.attempt) - logger.debug("run async 2") - await super()._run_async(pool) - logger.debug("run async 3") - - -class ReturnConnection(MaintenanceTask[ConnectionType]): - """Clean up and return a connection to the pool.""" - - def __init__(self, pool: "BasePool[Any]", conn: "ConnectionType"): - super().__init__(pool) - self.conn = conn - - def _run(self, pool: "ConnectionPool") -> None: - pool._return_connection(cast(Connection, self.conn)) - - async def _run_async(self, pool: "AsyncConnectionPool") -> None: - await pool._return_connection(cast(AsyncConnection, self.conn)) - await super()._run_async(pool) - - -class ShrinkPool(MaintenanceTask[ConnectionType]): - """If the pool can shrink, remove one connection. - - Re-schedule periodically and also reset the minimum number of connections - in the pool. - """ - - def _run(self, pool: "ConnectionPool") -> None: - # Reschedule the task now so that in case of any error we don't lose - # the periodic run. - pool.schedule_task(self, pool.max_idle) - pool._shrink_pool() - - async def _run_async(self, pool: "AsyncConnectionPool") -> None: - pool.schedule_task(self, pool.max_idle) - await pool._shrink_pool() - await super()._run_async(pool) diff --git a/tests/pool/test_pool.py b/tests/pool/test_pool.py index 6c400ac22..77993c74d 100644 --- a/tests/pool/test_pool.py +++ b/tests/pool/test_pool.py @@ -480,7 +480,7 @@ def test_grow(dsn, monkeypatch, retries): @pytest.mark.slow def test_shrink(dsn, monkeypatch): - from psycopg3.pool.tasks import ShrinkPool + from psycopg3.pool.pool import ShrinkPool results = [] diff --git a/tests/pool/test_pool_async.py b/tests/pool/test_pool_async.py index 3fad68414..246bccffa 100644 --- a/tests/pool/test_pool_async.py +++ b/tests/pool/test_pool_async.py @@ -369,16 +369,16 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch): assert "BAD" in recs[2].message -async def test_close_no_threads(dsn): +async def test_close_no_tasks(dsn): p = pool.AsyncConnectionPool(dsn) - assert p._sched_runner.is_alive() + assert not p._sched_runner.done() for t in p._workers: - assert t.is_alive() + assert not t.done() await p.close() - assert not p._sched_runner.is_alive() + assert p._sched_runner.done() for t in p._workers: - assert not t.is_alive() + assert t.done() async def test_putconn_no_pool(dsn): @@ -409,16 +409,6 @@ async def test_del_no_warning(dsn, recwarn): assert not recwarn -@pytest.mark.slow -async def test_del_stop_threads(dsn): - p = pool.AsyncConnectionPool(dsn) - ts = [p._sched_runner] + p._workers - del p - await asyncio.sleep(0.1) - for t in ts: - assert not t.is_alive() - - async def test_closed_getconn(dsn): p = pool.AsyncConnectionPool(dsn, minconn=1) assert not p.closed @@ -502,18 +492,18 @@ async def test_grow(dsn, monkeypatch, retries): @pytest.mark.slow async def test_shrink(dsn, monkeypatch): - from psycopg3.pool.tasks import ShrinkPool + from psycopg3.pool.async_pool import ShrinkPool results = [] - async def run_async_hacked(self, pool): + async def run_hacked(self, pool): n0 = pool._nconns await orig_run(self, pool) n1 = pool._nconns results.append((n0, n1)) - orig_run = ShrinkPool._run_async - monkeypatch.setattr(ShrinkPool, "_run_async", run_async_hacked) + orig_run = ShrinkPool._run + monkeypatch.setattr(ShrinkPool, "_run", run_hacked) async def worker(n): async with p.connection() as conn: