]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
tests: introduce a wait() helper function pq.PGconn tests
authorDenis Laxalde <denis.laxalde@dalibo.com>
Fri, 5 Apr 2024 10:25:10 +0000 (12:25 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 5 Apr 2024 18:35:16 +0000 (18:35 +0000)
tests/pq/test_pgconn.py

index ff18379a44ed7cabe19ff44506241b796102baea..7064a5f59f827f98ff447babf91bd5c34897b216 100644 (file)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import sys
 import ctypes
@@ -9,9 +11,33 @@ import pytest
 
 import psycopg
 from psycopg import pq
+from psycopg.pq.abc import PGconn
 import psycopg.generators
 
 
+def wait(
+    conn: PGconn,
+    poll_method: str = "connect_poll",
+    return_on: pq.PollingStatus = pq.PollingStatus.OK,
+    timeout: int | None = None,
+) -> None:
+    poll = getattr(conn, poll_method)
+    while True:
+        assert conn.status != pq.ConnStatus.BAD, conn.error_message
+        rv = poll()
+        if rv == return_on:
+            return
+        elif rv == pq.PollingStatus.READING:
+            select([conn.socket], [], [], timeout)
+        elif rv == pq.PollingStatus.WRITING:
+            select([], [conn.socket], [], timeout)
+        else:
+            pytest.fail(f"unexpected poll result: {rv}")
+    assert (
+        conn.status == pq.ConnStatus.OK
+    ), f"unexpected connection status: {conn.error_message}"
+
+
 def test_connectdb(dsn):
     conn = pq.PGconn.connect(dsn.encode())
     assert conn.status == pq.ConnStatus.OK, conn.error_message
@@ -31,20 +57,7 @@ def test_connectdb_badtype(baddsn):
 def test_connect_async(dsn):
     conn = pq.PGconn.connect_start(dsn.encode())
     conn.nonblocking = 1
-    while True:
-        assert conn.status != pq.ConnStatus.BAD
-        rv = conn.connect_poll()
-        if rv == pq.PollingStatus.OK:
-            break
-        elif rv == pq.PollingStatus.READING:
-            select([conn.socket], [], [])
-        elif rv == pq.PollingStatus.WRITING:
-            select([], [conn.socket], [])
-        else:
-            assert False, rv
-
-    assert conn.status == pq.ConnStatus.OK
-
+    wait(conn)
     conn.finish()
     with pytest.raises(psycopg.OperationalError):
         conn.connect_poll()
@@ -56,18 +69,7 @@ def test_connect_async_bad(dsn):
     parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real"
     dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items())
     conn = pq.PGconn.connect_start(dsn)
-    while True:
-        assert conn.status != pq.ConnStatus.BAD, conn.error_message
-        rv = conn.connect_poll()
-        if rv == pq.PollingStatus.FAILED:
-            break
-        elif rv == pq.PollingStatus.READING:
-            select([conn.socket], [], [])
-        elif rv == pq.PollingStatus.WRITING:
-            select([], [conn.socket], [])
-        else:
-            assert False, rv
-
+    wait(conn, return_on=pq.PollingStatus.FAILED)
     assert conn.status == pq.ConnStatus.BAD
 
 
@@ -157,17 +159,7 @@ def test_reset_async(pgconn):
     pgconn.exec_(b"select pg_terminate_backend(pg_backend_pid())")
     assert pgconn.status == pq.ConnStatus.BAD
     pgconn.reset_start()
-    while True:
-        rv = pgconn.reset_poll()
-        if rv == pq.PollingStatus.READING:
-            select([pgconn.socket], [], [])
-        elif rv == pq.PollingStatus.WRITING:
-            select([], [pgconn.socket], [])
-        else:
-            break
-
-    assert rv == pq.PollingStatus.OK
-    assert pgconn.status == pq.ConnStatus.OK
+    wait(pgconn, "reset_poll")
 
     pgconn.finish()
     with pytest.raises(psycopg.OperationalError):