From: Daniele Varrazzo Date: Sat, 13 Feb 2021 17:02:08 +0000 (+0100) Subject: Add basic pool functionality test for queuing and timeout X-Git-Tag: 3.0.dev0~87^2~77 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=3c936e503aeefdfbf5a0379615f00437588e3213;p=thirdparty%2Fpsycopg.git Add basic pool functionality test for queuing and timeout --- diff --git a/psycopg3/psycopg3/pool.py b/psycopg3/psycopg3/pool.py index aced55ba7..d3e081ba1 100644 --- a/psycopg3/psycopg3/pool.py +++ b/psycopg3/psycopg3/pool.py @@ -37,7 +37,7 @@ class ConnectionPool: maxconn: Optional[int] = None, name: Optional[str] = None, timeout_sec: float = 30.0, - nworkers: int = 1, + num_workers: int = 1, ): if maxconn is None: maxconn = minconn @@ -58,7 +58,7 @@ class ConnectionPool: self.minconn = minconn self.maxconn = maxconn self.timeout_sec = timeout_sec - self.nworkers = nworkers + self.num_workers = num_workers self._nconns = 0 # currently in the pool, out, being prepared self._pool: List[Connection] = [] @@ -67,7 +67,7 @@ class ConnectionPool: self._wqueue: "Queue[MaintenanceTask]" = Queue() self._workers: List[threading.Thread] = [] - for i in range(nworkers): + for i in range(num_workers): t = threading.Thread(target=self.worker, args=(self._wqueue,)) t.daemon = True t.start() diff --git a/tests/test_pool.py b/tests/test_pool.py index b716597b1..febe018c5 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,6 +1,26 @@ +from time import time +from threading import Thread + +import pytest + from psycopg3 import pool +def test_minconn_maxconn(dsn): + p = pool.ConnectionPool(dsn, num_workers=0) + assert p.minconn == p.maxconn == 4 + + p = pool.ConnectionPool(dsn, minconn=2, num_workers=0) + assert p.minconn == p.maxconn == 2 + + p = pool.ConnectionPool(dsn, minconn=2, maxconn=4, num_workers=0) + assert p.minconn == 2 + assert p.maxconn == 4 + + with pytest.raises(ValueError): + pool.ConnectionPool(dsn, minconn=4, maxconn=2, num_workers=0) + + def test_pool(dsn): p = pool.ConnectionPool(dsn, minconn=2, timeout_sec=1.0) with p.connection() as conn: @@ -13,3 +33,103 @@ def test_pool(dsn): with p.connection() as conn: assert conn.pgconn.backend_pid in (pid1, pid2) + + +@pytest.mark.slow +def test_queue(dsn): + p = pool.ConnectionPool(dsn, minconn=2) + results = [] + + def worker(n): + t0 = time() + with p.connection() as conn: + (pid,) = conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ).fetchone() + t1 = time() + results.append((n, t1 - t0, pid)) + + ts = [] + for i in range(6): + t = Thread(target=worker, args=(i,)) + t.start() + ts.append(t) + + for t in ts: + t.join() + + assert len([r for r in results if 0.2 < r[1] < 0.35]) == 2 + assert len([r for r in results if 0.4 < r[1] < 0.55]) == 2 + assert len([r for r in results if 0.5 < r[1] < 0.75]) == 2 + assert len(set(r[2] for r in results)) == 2 + + +@pytest.mark.slow +def test_queue_timeout(dsn): + p = pool.ConnectionPool(dsn, minconn=2, timeout_sec=0.1) + results = [] + errors = [] + + def worker(n): + t0 = time() + try: + with p.connection() as conn: + (pid,) = conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ).fetchone() + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + ts = [] + for i in range(4): + t = Thread(target=worker, args=(i,)) + t.start() + ts.append(t) + + for t in ts: + t.join() + + assert len(results) == 2 + assert len(errors) == 2 + for e in errors: + assert 0.1 < e[1] < 0.15 + + +@pytest.mark.slow +def test_queue_timeout_override(dsn): + p = pool.ConnectionPool(dsn, minconn=2, timeout_sec=0.1) + results = [] + errors = [] + + def worker(n): + t0 = time() + timeout = 0.25 if n == 3 else None + try: + with p.connection(timeout_sec=timeout) as conn: + (pid,) = conn.execute( + "select pg_backend_pid() from pg_sleep(0.2)" + ).fetchone() + except pool.PoolTimeout as e: + t1 = time() + errors.append((n, t1 - t0, e)) + else: + t1 = time() + results.append((n, t1 - t0, pid)) + + ts = [] + for i in range(4): + t = Thread(target=worker, args=(i,)) + t.start() + ts.append(t) + + for t in ts: + t.join() + + assert len(results) == 3 + assert len(errors) == 1 + for e in errors: + assert 0.1 < e[1] < 0.15