]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make testable, and test, the different waiting implementation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Dec 2020 17:09:02 +0000 (18:09 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Dec 2020 17:14:35 +0000 (18:14 +0100)
Also fixed segfault calling flush on a closed connection.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/waiting.py
psycopg3_c/psycopg3_c/pq/pgconn.pyx
tests/test_waiting.py [new file with mode: 0644]

index 765add5baf895f86fbde66974098bcc70e238055..d22433a8e0f7286b4ec4f407dc80801d75990a01 100644 (file)
@@ -623,7 +623,7 @@ class AsyncConnection(BaseConnection):
 
     @classmethod
     async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV:
-        return await waiting.wait_async_conn(gen)
+        return await waiting.wait_conn_async(gen)
 
     def _set_client_encoding(self, name: str) -> None:
         raise AttributeError(
index 9c68148263983ef2c3dfdc792b3842f04d3942f1..e09c7e5be05ae1a44c11fd6271bd2fec32dbd1b5 100644 (file)
@@ -493,6 +493,8 @@ class PGconn:
             raise PQerror(f"setting nonblocking failed: {error_message(self)}")
 
     def flush(self) -> int:
+        if not self.pgconn_ptr:
+            raise PQerror("flushing failed: the connection is closed")
         rv: int = impl.PQflush(self.pgconn_ptr)
         if rv < 0:
             raise PQerror(f"flushing failed: {error_message(self)}")
index 8cb92e5889a26aa9df68dfbb88dbb377476b5ee0..9234c7211662634076c827aab045b4d1060e4332 100644 (file)
@@ -26,6 +26,7 @@ ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
 # Waiting protocol types
 
 RV = TypeVar("RV")
+
 PQGenConn = Generator[Tuple[int, "Wait"], "Ready", RV]
 """Generator for processes where the connection file number can change.
 
index e227aabc785b7ce7c1e3d1ecfe94bf986d9d3063..ba6e005d54ac47dfd2c9ad4448170dd49e34a58d 100644 (file)
@@ -31,7 +31,9 @@ class Ready(IntEnum):
     W = EVENT_WRITE
 
 
-def wait(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
+def wait_selector(
+    gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
+) -> RV:
     """
     Wait for a generator using the best strategy available.
 
@@ -142,7 +144,7 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
         return rv
 
 
-async def wait_async_conn(gen: PQGenConn[RV]) -> RV:
+async def wait_conn_async(gen: PQGenConn[RV]) -> RV:
     """
     Coroutine waiting for a connection generator to complete.
 
@@ -243,4 +245,6 @@ if (
     selectors.DefaultSelector  # type: ignore[comparison-overlap]
     is selectors.EpollSelector
 ):
-    wait = wait_epoll  # noqa: F811
+    wait = wait_epoll
+else:
+    wait = wait_selector
index 48f453df29564d1a6944530635d29a1c5d5ef94b..0e250d48783617484a8add0ee4bff410163a7b79 100644 (file)
@@ -400,9 +400,11 @@ cdef class PGconn:
             raise PQerror(f"setting nonblocking failed: {error_message(self)}")
 
     def flush(self) -> int:
+        if self.pgconn_ptr == NULL:
+            raise PQerror(f"flushing failed: the connection is closed")
         cdef int rv = libpq.PQflush(self.pgconn_ptr)
         if rv < 0:
-            raise PQerror(f"flushing failed:{error_message(self)}")
+            raise PQerror(f"flushing failed: {error_message(self)}")
         return rv
 
     def get_cancel(self) -> PGcancel:
diff --git a/tests/test_waiting.py b/tests/test_waiting.py
new file mode 100644 (file)
index 0000000..c5f3b71
--- /dev/null
@@ -0,0 +1,107 @@
+import select
+
+import pytest
+
+import psycopg3
+from psycopg3 import waiting
+from psycopg3 import generators
+from psycopg3.pq import ConnStatus, ExecStatus
+
+
+skip_no_epoll = pytest.mark.skipif(
+    not hasattr(select, "epoll"), reason="epoll not available"
+)
+
+timeouts = [
+    {},
+    {"timeout": None},
+    {"timeout": 0},
+    {"timeout": 0.1},
+    {"timeout": 10},
+]
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait_conn(dsn, timeout):
+    gen = generators.connect(dsn)
+    conn = waiting.wait_conn(gen, **timeout)
+    assert conn.status == ConnStatus.OK
+
+
+def test_wait_conn_bad(dsn):
+    gen = generators.connect("dbname=nosuchdb")
+    with pytest.raises(psycopg3.OperationalError):
+        waiting.wait_conn(gen)
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait(pgconn, timeout):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    (res,) = waiting.wait(gen, pgconn.socket, **timeout)
+    assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait_selector(pgconn, timeout):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    (res,) = waiting.wait_selector(gen, pgconn.socket, **timeout)
+    assert res.status == ExecStatus.TUPLES_OK
+
+
+def test_wait_selector_bad(pgconn):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        waiting.wait_selector(gen, pgconn.socket)
+
+
+@skip_no_epoll
+@pytest.mark.parametrize("timeout", timeouts)
+def test_wait_epoll(pgconn, timeout):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    (res,) = waiting.wait_epoll(gen, pgconn.socket, **timeout)
+    assert res.status == ExecStatus.TUPLES_OK
+
+
+@skip_no_epoll
+def test_wait_epoll_bad(pgconn):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    (res,) = waiting.wait_epoll(gen, pgconn.socket)
+    assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.asyncio
+async def test_wait_conn_async(dsn):
+    gen = generators.connect(dsn)
+    conn = await waiting.wait_conn_async(gen)
+    assert conn.status == ConnStatus.OK
+
+
+@pytest.mark.asyncio
+async def test_wait_conn_async_bad(dsn):
+    gen = generators.connect("dbname=nosuchdb")
+    with pytest.raises(psycopg3.OperationalError):
+        await waiting.wait_conn_async(gen)
+
+
+@pytest.mark.asyncio
+async def test_wait_async(pgconn):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    (res,) = await waiting.wait_async(gen, pgconn.socket)
+    assert res.status == ExecStatus.TUPLES_OK
+
+
+@pytest.mark.asyncio
+async def test_wait_async_bad(pgconn):
+    pgconn.send_query(b"select 1")
+    gen = generators.execute(pgconn)
+    socket = pgconn.socket
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        await waiting.wait_async(gen, socket)