]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add open param to pool init
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 3 Jan 2022 12:09:38 +0000 (13:09 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 3 Jan 2022 15:41:10 +0000 (16:41 +0100)
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py

index 8c1b66849081cd5416e613867e28ed382895d7be..dc34528fbebbe7f647346208282d2b5326f8b3eb 100644 (file)
@@ -42,6 +42,7 @@ class BasePool(Generic[ConnectionType]):
         kwargs: Optional[Dict[str, Any]] = None,
         min_size: int = 4,
         max_size: Optional[int] = None,
+        open: bool = True,
         name: Optional[str] = None,
         timeout: float = 30.0,
         max_waiting: int = 0,
index 71479d41b656522fd0918fcd0c1acb6fea1ebbb2..98112daa72bc819791370d385ff4ae90190770c3 100644 (file)
@@ -32,6 +32,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self,
         conninfo: str = "",
         *,
+        open: bool = True,
         connection_class: Type[Connection[Any]] = Connection,
         configure: Optional[Callable[[Connection[Any]], None]] = None,
         reset: Optional[Callable[[Connection[Any]], None]] = None,
@@ -54,7 +55,8 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
         super().__init__(conninfo, **kwargs)
 
-        self.open()
+        if open:
+            self.open()
 
     def __del__(self) -> None:
         # If the '_closed' property is not set we probably failed in __init__.
@@ -228,13 +230,22 @@ class ConnectionPool(BasePool[Connection[Any]]):
 
         No-op if the pool is already opened.
         """
-        if not self._closed:
-            return
+        with self._lock:
+            if not self._closed:
+                return
 
-        self._check_open()
+            self._check_open()
 
+            self._start_workers()
+
+            self._closed = False
+            self._opened = True
+
+    def _start_workers(self) -> None:
         self._sched_runner = threading.Thread(
-            target=self._sched.run, name=f"{self.name}-scheduler", daemon=True
+            target=self._sched.run,
+            name=f"{self.name}-scheduler",
+            daemon=True,
         )
         assert not self._workers
         for i in range(self.num_workers):
@@ -259,9 +270,6 @@ class ConnectionPool(BasePool[Connection[Any]]):
         # remained unused.
         self.schedule_task(ShrinkPool(self), self.max_idle)
 
-        self._closed = False
-        self._opened = True
-
     def close(self, timeout: float = 5.0) -> None:
         """Close the pool and make it unavailable to new clients.
 
@@ -330,6 +338,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
                     )
 
     def __enter__(self) -> "ConnectionPool":
+        self.open()
         return self
 
     def __exit__(
index 46a3bd062475aa8b0d1e7c63f8a983a61448ada4..46a6b2f2b6b1533c8c6444de571b6b69cf621a8a 100644 (file)
@@ -31,6 +31,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self,
         conninfo: str = "",
         *,
+        open: bool = True,
         connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
         configure: Optional[
             Callable[[AsyncConnection[Any]], Awaitable[None]]
@@ -63,7 +64,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
 
         super().__init__(conninfo, **kwargs)
 
-        self.open()
+        if open:
+            self.open()
 
     async def wait(self, timeout: float = 30.0) -> None:
         async with self._lock:
@@ -191,15 +193,17 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             await self._return_connection(conn)
 
     def open(self) -> None:
-        """Open the pool by starting worker tasks.
-
-        No-op if the pool is already opened.
-        """
         if not self._closed:
             return
 
         self._check_open()
 
+        self._start_workers()
+
+        self._closed = False
+        self._opened = True
+
+    def _start_workers(self) -> None:
         self._sched_runner = create_task(
             self._sched.run(), name=f"{self.name}-scheduler"
         )
@@ -218,9 +222,6 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         # remained unused.
         self.run_task(Schedule(self, ShrinkPool(self), self.max_idle))
 
-        self._closed = False
-        self._opened = True
-
     async def close(self, timeout: float = 5.0) -> None:
         if self._closed:
             return
@@ -278,6 +279,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             )
 
     async def __aenter__(self) -> "AsyncConnectionPool":
+        self.open()
         return self
 
     async def __aexit__(
index ba2812b58d0aab2b2710cb83958aaa6df98292e6..898c2ba0303d297693dd0e05d5ddc6651c1c91a2 100644 (file)
@@ -684,6 +684,57 @@ def test_closed_queue(dsn):
     assert len(success) == 2
 
 
+def test_open_explicit(dsn):
+    p = pool.ConnectionPool(dsn, open=False)
+    assert p.closed
+    with pytest.raises(pool.PoolClosed):
+        p.getconn()
+
+    with pytest.raises(pool.PoolClosed):
+        with p.connection():
+            pass
+
+    p.open()
+    try:
+        assert not p.closed
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+    finally:
+        p.close()
+
+
+def test_open_context(dsn):
+    p = pool.ConnectionPool(dsn, open=False)
+    assert p.closed
+
+    with p:
+        assert not p.closed
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+    assert p.closed
+
+
+def test_open_no_op(dsn):
+    p = pool.ConnectionPool(dsn)
+    try:
+        assert not p.closed
+        p.open()
+        assert not p.closed
+
+        with p.connection() as conn:
+            cur = conn.execute("select 1")
+            assert cur.fetchone() == (1,)
+
+    finally:
+        p.close()
+
+
 def test_reopen(dsn):
     p = pool.ConnectionPool(dsn)
     with p.connection() as conn:
index 7c2a9ae9ff805867fde94dc1768a1c172aa1fa6c..ea2868dcc75b4e016c13a92f7f6b36c82d554be5 100644 (file)
@@ -673,6 +673,57 @@ async def test_closed_queue(dsn):
     assert len(success) == 2
 
 
+async def test_open_explicit(dsn):
+    p = pool.AsyncConnectionPool(dsn, open=False)
+    assert p.closed
+    with pytest.raises(pool.PoolClosed):
+        await p.getconn()
+
+    with pytest.raises(pool.PoolClosed):
+        async with p.connection():
+            pass
+
+    p.open()
+    try:
+        assert not p.closed
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+    finally:
+        await p.close()
+
+
+async def test_open_context(dsn):
+    p = pool.AsyncConnectionPool(dsn, open=False)
+    assert p.closed
+
+    async with p:
+        assert not p.closed
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+    assert p.closed
+
+
+async def test_open_no_op(dsn):
+    p = pool.AsyncConnectionPool(dsn)
+    try:
+        assert not p.closed
+        p.open()
+        assert not p.closed
+
+        async with p.connection() as conn:
+            cur = await conn.execute("select 1")
+            assert await cur.fetchone() == (1,)
+
+    finally:
+        await p.close()
+
+
 async def test_reopen(dsn):
     p = pool.AsyncConnectionPool(dsn)
     async with p.connection() as conn: