git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add an open() method to connection pool classes
author    Denis Laxalde <denis.laxalde@dalibo.com>
Mon, 15 Nov 2021 08:13:04 +0000 (09:13 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 3 Jan 2022 15:41:10 +0000 (16:41 +0100)
This method is now responsible for setting the '_closed' attribute,
which therefore defaults to True in the base class, and the
'_sched_runner' attribute, which is reset to None in close().
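
For reference, a minimal usage sketch of the new method on the synchronous
pool (the connection string is a placeholder; at this commit __init__ still
calls open() itself, so the explicit call below is a no-op on a freshly
constructed pool):

    from psycopg_pool import ConnectionPool

    pool = ConnectionPool("dbname=test")   # placeholder conninfo
    pool.open()                            # no-op here: __init__ already opened the pool

    # borrow a connection from the pool and return it on exit
    with pool.connection() as conn:
        conn.execute("SELECT 1")

    pool.close()                           # stops the scheduler and worker threads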

docs/api/pool.rst
docs/news_pool.rst
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

docs/api/pool.rst
index 3604c614e9fa231e1960726516b739c289ed1e82..66b6de2abfa110ff91b089900e460a6141343544 100644
@@ -145,6 +145,8 @@ The `!ConnectionPool` class
 
           # the connection is now back in the pool
    
+   .. automethod:: open
+
    .. automethod:: close
 
       .. note::
@@ -233,6 +235,8 @@ listed here.
 
           # the connection is now back in the pool
    
+   .. automethod:: open
+
    .. automethod:: close
 
       .. note::
docs/news_pool.rst
index 3dfca7fc37230e09b1ee4262a76097ae7152accd..62c6a33558ca0a5d96b30285d0d3f381b695a687 100644
 Current release
 ---------------
 
+psycopg_pool 3.1.0
+^^^^^^^^^^^^^^^^^^
+
+- Add `ConnectionPool.open()` and `AsyncConnectionPool.open()`
+  (:ticket:`#155`).
+
 psycopg_pool 3.0.2
 ^^^^^^^^^^^^^^^^^^
 
psycopg_pool/psycopg_pool/base.py
index 160221507353cea71a987770139115bee4f920e6..129ffc623c432c9638b3bc4c898cef306ab19d50 100644
@@ -94,9 +94,7 @@ class BasePool(Generic[ConnectionType]):
         # connections to the pool.
         self._growing = False
 
-        # _close should be the last property to be set in the state
-        # to avoid warning on __del__ in case __init__ fails.
-        self._closed = False
+        self._closed = True
 
     def __repr__(self) -> str:
         return (
psycopg_pool/psycopg_pool/pool.py
index dcbcde02ac547a3af78ec60753cf8a02b572d814..a6ec95601016646272b630043b74e070e60a1467 100644
@@ -48,35 +48,13 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self._pool_full_event: Optional[threading.Event] = None
 
         self._sched = Scheduler()
+        self._sched_runner: Optional[threading.Thread] = None
         self._tasks: "Queue[MaintenanceTask]" = Queue()
         self._workers: List[threading.Thread] = []
 
         super().__init__(conninfo, **kwargs)
 
-        self._sched_runner = threading.Thread(
-            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
-        )
-        for i in range(self.num_workers):
-            t = threading.Thread(
-                target=self.worker,
-                args=(self._tasks,),
-                name=f"{self.name}-worker-{i}",
-                daemon=True,
-            )
-            self._workers.append(t)
-
-        # The object state is complete. Start the worker threads
-        self._sched_runner.start()
-        for t in self._workers:
-            t.start()
-
-        # populate the pool with initial min_size connections in background
-        for i in range(self._nconns):
-            self.run_task(AddConnection(self))
-
-        # Schedule a task to shrink the pool if connections over min_size have
-        # remained unused.
-        self.schedule_task(ShrinkPool(self), self.max_idle)
+        self.open()
 
     def __del__(self) -> None:
         # If the '_closed' property is not set we probably failed in __init__.
@@ -254,6 +232,42 @@ class ConnectionPool(BasePool[Connection[Any]]):
         else:
             self._return_connection(conn)
 
+    def open(self) -> None:
+        """Open the pool by starting worker threads.
+
+        No-op if the pool is already opened.
+        """
+        if not self._closed:
+            return
+
+        self._sched_runner = threading.Thread(
+            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
+        )
+        assert not self._workers
+        for i in range(self.num_workers):
+            t = threading.Thread(
+                target=self.worker,
+                args=(self._tasks,),
+                name=f"{self.name}-worker-{i}",
+                daemon=True,
+            )
+            self._workers.append(t)
+
+        # The object state is complete. Start the worker threads
+        self._sched_runner.start()
+        for t in self._workers:
+            t.start()
+
+        # populate the pool with initial min_size connections in background
+        for i in range(self._nconns):
+            self.run_task(AddConnection(self))
+
+        # Schedule a task to shrink the pool if connections over min_size have
+        # remained unused.
+        self.schedule_task(ShrinkPool(self), self.max_idle)
+
+        self._closed = False
+
     def close(self, timeout: float = 5.0) -> None:
         """Close the pool and make it unavailable to new clients.
 
@@ -297,6 +311,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
             conn.close()
 
         # Wait for the worker threads to terminate
+        assert self._sched_runner is not None
         if timeout > 0:
             for t in [self._sched_runner] + self._workers:
                 if not t.is_alive():
@@ -309,6 +324,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
                         self.name,
                         timeout,
                     )
+        self._sched_runner = None
 
     def __enter__(self) -> "ConnectionPool":
         return self
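
The shape of the change above (an idempotent open() guarded by the '_closed'
flag that the base class now initialises to True, with close() dropping the
scheduler reference again) can be illustrated in isolation. This is only a
sketch of the pattern, not the pool's actual code:

    import threading
    import time
    from typing import Optional

    class Service:
        def __init__(self) -> None:
            self._closed = True                  # closed until open() is called
            self._runner: Optional[threading.Thread] = None

        def open(self) -> None:
            if not self._closed:                 # no-op if already opened
                return
            self._runner = threading.Thread(target=self._run, daemon=True)
            self._runner.start()
            self._closed = False

        def close(self, timeout: float = 5.0) -> None:
            if self._closed:
                return
            self._closed = True
            assert self._runner is not None
            self._runner.join(timeout)
            self._runner = None                  # as in the pool's close()

        def _run(self) -> None:
            while not self._closed:              # placeholder worker loop
                time.sleep(0.1)
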
psycopg_pool/psycopg_pool/pool_async.py
index 7c4f4548a949988d7bfaa6c20a12286e4af7d03d..cf6887ef4eafcc251937d20de49429cfeb15eab9 100644
@@ -57,28 +57,13 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self._pool_full_event: Optional[asyncio.Event] = None
 
         self._sched = AsyncScheduler()
+        self._sched_runner: Optional[Task[None]] = None
         self._tasks: "asyncio.Queue[MaintenanceTask]" = asyncio.Queue()
         self._workers: List[Task[None]] = []
 
         super().__init__(conninfo, **kwargs)
 
-        self._sched_runner = create_task(
-            self._sched.run(), name=f"{self.name}-scheduler"
-        )
-        for i in range(self.num_workers):
-            t = create_task(
-                self.worker(self._tasks),
-                name=f"{self.name}-worker-{i}",
-            )
-            self._workers.append(t)
-
-        # populate the pool with initial min_size connections in background
-        for i in range(self._nconns):
-            self.run_task(AddConnection(self))
-
-        # Schedule a task to shrink the pool if connections over min_size have
-        # remained unused.
-        self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
+        self.open()
 
     async def wait(self, timeout: float = 30.0) -> None:
         async with self._lock:
@@ -205,6 +190,34 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         else:
             await self._return_connection(conn)
 
+    def open(self) -> None:
+        """Open the pool by starting worker tasks.
+
+        No-op if the pool is already opened.
+        """
+        if not self._closed:
+            return
+
+        self._sched_runner = create_task(
+            self._sched.run(), name=f"{self.name}-scheduler"
+        )
+        for i in range(self.num_workers):
+            t = create_task(
+                self.worker(self._tasks),
+                name=f"{self.name}-worker-{i}",
+            )
+            self._workers.append(t)
+
+        # populate the pool with initial min_size connections in background
+        for i in range(self._nconns):
+            self.run_task(AddConnection(self))
+
+        # Schedule a task to shrink the pool if connections over min_size have
+        # remained unused.
+        self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
+
+        self._closed = False
+
     async def close(self, timeout: float = 5.0) -> None:
         if self._closed:
             return
@@ -238,6 +251,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             await conn.close()
 
         # Wait for the worker tasks to terminate
+        assert self._sched_runner is not None
         wait = asyncio.gather(self._sched_runner, *self._workers)
         try:
             if timeout > 0:
@@ -250,6 +264,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                 self.name,
                 timeout,
             )
+        self._sched_runner = None
 
     async def __aenter__(self) -> "AsyncConnectionPool":
         return self
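
The asynchronous pool gains the same method. A hedged sketch of its use
(again with a placeholder connection string; at this commit open() is a
plain function while close() is a coroutine):

    import asyncio
    from psycopg_pool import AsyncConnectionPool

    async def main() -> None:
        pool = AsyncConnectionPool("dbname=test")  # placeholder conninfo
        pool.open()                                # no-op: __init__ already opened the pool
        async with pool.connection() as conn:
            await conn.execute("SELECT 1")
        await pool.close()

    asyncio.run(main())
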
tests/pool/test_pool.py
index 748fd93690a4cc46ac2271e058a399e8fa218f05..e96e0518e8e0b13d0a4f446f5c5196a5b609fb66 100644
@@ -561,12 +561,12 @@ def test_fail_rollback_close(dsn, caplog, monkeypatch):
 
 def test_close_no_threads(dsn):
     p = pool.ConnectionPool(dsn)
-    assert p._sched_runner.is_alive()
+    assert p._sched_runner and p._sched_runner.is_alive()
     for t in p._workers:
         assert t.is_alive()
 
     p.close()
-    assert not p._sched_runner.is_alive()
+    assert p._sched_runner is None
     for t in p._workers:
         assert not t.is_alive()
 
@@ -603,6 +603,7 @@ def test_del_no_warning(dsn, recwarn):
 @pytest.mark.slow
 def test_del_stop_threads(dsn):
     p = pool.ConnectionPool(dsn)
+    assert p._sched_runner is not None
     ts = [p._sched_runner] + p._workers
     del p
     sleep(0.1)
tests/pool/test_pool_async.py
index 9ec20f22a5acb4cc8590a97a7f17b14cd2ca1361..2efab352f85b176d67a2eaef3d175f1b32ec4bba 100644
@@ -576,12 +576,12 @@ async def test_fail_rollback_close(dsn, caplog, monkeypatch):
 
 async def test_close_no_tasks(dsn):
     p = pool.AsyncConnectionPool(dsn)
-    assert not p._sched_runner.done()
+    assert p._sched_runner and not p._sched_runner.done()
     for t in p._workers:
         assert not t.done()
 
     await p.close()
-    assert p._sched_runner.done()
+    assert p._sched_runner is None
     for t in p._workers:
         assert t.done()