From: Daniele Varrazzo Date: Tue, 24 Oct 2023 21:51:17 +0000 (+0200) Subject: feat: add `timeout` and `stop_after` parameters to Connection.notifies X-Git-Tag: 3.2.0~87^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f88dcbc0a314abca080715a840b508f1e08bc77d;p=thirdparty%2Fpsycopg.git feat: add `timeout` and `stop_after` parameters to Connection.notifies Close #340 --- diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 2c503a126..cb0244aa5 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -10,6 +10,7 @@ Psycopg connection object (sync version) from __future__ import annotations import logging +from time import monotonic from types import TracebackType from typing import Any, Generator, Iterator, List, Optional from typing import Type, Union, cast, overload, TYPE_CHECKING @@ -39,6 +40,8 @@ from threading import Lock if TYPE_CHECKING: from .pq.abc import PGconn +_WAIT_INTERVAL = 0.1 + TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY @@ -277,20 +280,56 @@ class Connection(BaseConnection[Row]): with tx: yield tx - def notifies(self) -> Generator[Notify, None, None]: + def notifies( + self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None + ) -> Generator[Notify, None, None]: """ Yield `Notify` objects as soon as they are received from the database. + + :param timeout: maximum amount of time to wait for notifications. + `!None` means no timeout. + :param stop_after: stop after receiving this number of notifications. + You might actually receive more than this number if more than one + notifications arrives in the same packet. """ + # Allow interrupting the wait with a signal by reducing a long timeout + # into shorter interval. + if timeout is not None: + deadline = monotonic() + timeout + timeout = min(timeout, _WAIT_INTERVAL) + else: + deadline = None + timeout = _WAIT_INTERVAL + + nreceived = 0 + while True: - with self.lock: - try: - ns = self.wait(notifies(self.pgconn)) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) - enc = pgconn_encoding(self.pgconn) + # Collect notifications. Also get the connection encoding if any + # notification is received to makes sure that they are consistent. + try: + with self.lock: + ns = self.wait(notifies(self.pgconn), timeout=timeout) + if ns: + enc = pgconn_encoding(self.pgconn) + except e._NO_TRACEBACK as ex: + raise ex.with_traceback(None) + + # Emit the notifications received. for pgn in ns: n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n + nreceived += 1 + + # Stop if we have received enough notifications. + if stop_after is not None and nreceived >= stop_after: + break + + # Check the deadline after the loop to ensure that timeout=0 + # polls at least once. + if deadline: + timeout = min(_WAIT_INTERVAL, deadline - monotonic()) + if timeout < 0.0: + break @contextmanager def pipeline(self) -> Iterator[Pipeline]: @@ -312,7 +351,7 @@ class Connection(BaseConnection[Row]): assert pipeline is self._pipeline self._pipeline = None - def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + def wait(self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL) -> RV: """ Consume a generator operating on the connection. diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 3a57df375..d810d45b2 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -7,6 +7,7 @@ Psycopg connection object (async version) from __future__ import annotations import logging +from time import monotonic from types import TracebackType from typing import Any, AsyncGenerator, AsyncIterator, List, Optional from typing import Type, Union, cast, overload, TYPE_CHECKING @@ -41,6 +42,8 @@ else: if TYPE_CHECKING: from .pq.abc import PGconn +_WAIT_INTERVAL = 0.1 + TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY @@ -293,20 +296,56 @@ class AsyncConnection(BaseConnection[Row]): async with tx: yield tx - async def notifies(self) -> AsyncGenerator[Notify, None]: + async def notifies( + self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None + ) -> AsyncGenerator[Notify, None]: """ Yield `Notify` objects as soon as they are received from the database. + + :param timeout: maximum amount of time to wait for notifications. + `!None` means no timeout. + :param stop_after: stop after receiving this number of notifications. + You might actually receive more than this number if more than one + notifications arrives in the same packet. """ + # Allow interrupting the wait with a signal by reducing a long timeout + # into shorter interval. + if timeout is not None: + deadline = monotonic() + timeout + timeout = min(timeout, _WAIT_INTERVAL) + else: + deadline = None + timeout = _WAIT_INTERVAL + + nreceived = 0 + while True: - async with self.lock: - try: - ns = await self.wait(notifies(self.pgconn)) - except e._NO_TRACEBACK as ex: - raise ex.with_traceback(None) - enc = pgconn_encoding(self.pgconn) + # Collect notifications. Also get the connection encoding if any + # notification is received to makes sure that they are consistent. + try: + async with self.lock: + ns = await self.wait(notifies(self.pgconn), timeout=timeout) + if ns: + enc = pgconn_encoding(self.pgconn) + except e._NO_TRACEBACK as ex: + raise ex.with_traceback(None) + + # Emit the notifications received. for pgn in ns: n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid) yield n + nreceived += 1 + + # Stop if we have received enough notifications. + if stop_after is not None and nreceived >= stop_after: + break + + # Check the deadline after the loop to ensure that timeout=0 + # polls at least once. + if deadline: + timeout = min(_WAIT_INTERVAL, deadline - monotonic()) + if timeout < 0.0: + break @asynccontextmanager async def pipeline(self) -> AsyncIterator[AsyncPipeline]: @@ -328,7 +367,9 @@ class AsyncConnection(BaseConnection[Row]): assert pipeline is self._pipeline self._pipeline = None - async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: + async def wait( + self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL + ) -> RV: """ Consume a generator operating on the connection. diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index 150d77477..a14fb93e6 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -6,7 +6,7 @@ import threading import subprocess as sp from asyncio import create_task from asyncio.queues import Queue -from typing import List, Tuple +from typing import List import pytest @@ -58,50 +58,6 @@ async def test_concurrent_execution(aconn_cls, dsn): assert time.time() - t0 < 0.8, "something broken in concurrency" -@pytest.mark.slow -@pytest.mark.timing -@pytest.mark.crdb_skip("notify") -async def test_notifies(aconn_cls, aconn, dsn): - nconn = await aconn_cls.connect(dsn, autocommit=True) - npid = nconn.pgconn.backend_pid - - async def notifier(): - cur = nconn.cursor() - await asyncio.sleep(0.25) - await cur.execute("notify foo, '1'") - await asyncio.sleep(0.25) - await cur.execute("notify foo, '2'") - await nconn.close() - - async def receiver(): - await aconn.set_autocommit(True) - cur = aconn.cursor() - await cur.execute("listen foo") - gen = aconn.notifies() - async for n in gen: - ns.append((n, time.time())) - if len(ns) >= 2: - await gen.aclose() - - ns: List[Tuple[psycopg.Notify, float]] = [] - t0 = time.time() - workers = [notifier(), receiver()] - await asyncio.gather(*workers) - assert len(ns) == 2 - - n, t1 = ns[0] - assert n.pid == npid - assert n.channel == "foo" - assert n.payload == "1" - assert t1 - t0 == pytest.approx(0.25, abs=0.05) - - n, t1 = ns[1] - assert n.pid == npid - assert n.channel == "foo" - assert n.payload == "2" - assert t1 - t0 == pytest.approx(0.5, abs=0.05) - - async def canceller(aconn, errors): try: await asyncio.sleep(0.5) diff --git a/tests/test_connection.py b/tests/test_connection.py index 1d9217d2b..f17f8d2a0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -9,7 +9,7 @@ import weakref from typing import Any, List import psycopg -from psycopg import Notify, pq, errors as e +from psycopg import pq, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo @@ -528,47 +528,6 @@ def test_notice_handlers(conn, caplog): conn.remove_notice_handler(cb1) -@pytest.mark.crdb_skip("notify") -def test_notify_handlers(conn): - nots1 = [] - nots2 = [] - - def cb1(n): - nots1.append(n) - - conn.add_notify_handler(cb1) - conn.add_notify_handler(lambda n: nots2.append(n)) - - conn.set_autocommit(True) - cur = conn.cursor() - cur.execute("listen foo") - cur.execute("notify foo, 'n1'") - - assert len(nots1) == 1 - n = nots1[0] - assert n.channel == "foo" - assert n.payload == "n1" - assert n.pid == conn.pgconn.backend_pid - - assert len(nots2) == 1 - assert nots2[0] == nots1[0] - - conn.remove_notify_handler(cb1) - cur.execute("notify foo, 'n2'") - - assert len(nots1) == 1 - assert len(nots2) == 2 - n = nots2[1] - assert isinstance(n, Notify) - assert n.channel == "foo" - assert n.payload == "n2" - assert n.pid == conn.pgconn.backend_pid - assert hash(n) - - with pytest.raises(ValueError): - conn.remove_notify_handler(cb1) - - def test_execute(conn): cur = conn.execute("select %s, %s", [10, 20]) assert cur.fetchone() == (10, 20) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index a98f0f80b..d7aa7ca8b 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -6,7 +6,7 @@ import weakref from typing import Any, List import psycopg -from psycopg import Notify, pq, errors as e +from psycopg import pq, errors as e from psycopg.rows import tuple_row from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo @@ -526,47 +526,6 @@ async def test_notice_handlers(aconn, caplog): aconn.remove_notice_handler(cb1) -@pytest.mark.crdb_skip("notify") -async def test_notify_handlers(aconn): - nots1 = [] - nots2 = [] - - def cb1(n): - nots1.append(n) - - aconn.add_notify_handler(cb1) - aconn.add_notify_handler(lambda n: nots2.append(n)) - - await aconn.set_autocommit(True) - cur = aconn.cursor() - await cur.execute("listen foo") - await cur.execute("notify foo, 'n1'") - - assert len(nots1) == 1 - n = nots1[0] - assert n.channel == "foo" - assert n.payload == "n1" - assert n.pid == aconn.pgconn.backend_pid - - assert len(nots2) == 1 - assert nots2[0] == nots1[0] - - aconn.remove_notify_handler(cb1) - await cur.execute("notify foo, 'n2'") - - assert len(nots1) == 1 - assert len(nots2) == 2 - n = nots2[1] - assert isinstance(n, Notify) - assert n.channel == "foo" - assert n.payload == "n2" - assert n.pid == aconn.pgconn.backend_pid - assert hash(n) - - with pytest.raises(ValueError): - aconn.remove_notify_handler(cb1) - - async def test_execute(aconn): cur = await aconn.execute("select %s, %s", [10, 20]) assert await cur.fetchone() == (10, 20) diff --git a/tests/test_notify.py b/tests/test_notify.py new file mode 100644 index 000000000..c67b33115 --- /dev/null +++ b/tests/test_notify.py @@ -0,0 +1,195 @@ +# WARNING: this file is auto-generated by 'async_to_sync.py' +# from the original file 'test_notify_async.py' +# DO NOT CHANGE! Change the original file instead. +from __future__ import annotations + +from time import time + +import pytest +from psycopg import Notify + +from .acompat import sleep, gather, spawn + +pytestmark = pytest.mark.crdb_skip("notify") + + +def test_notify_handlers(conn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + conn.add_notify_handler(cb1) + conn.add_notify_handler(lambda n: nots2.append(n)) + + conn.set_autocommit(True) + conn.execute("listen foo") + conn.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == conn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + conn.remove_notify_handler(cb1) + conn.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == conn.pgconn.backend_pid + assert hash(n) + + with pytest.raises(ValueError): + conn.remove_notify_handler(cb1) + + +@pytest.mark.slow +@pytest.mark.timing +def test_notify(conn_cls, conn, dsn): + npid = None + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + nonlocal npid + npid = nconn.pgconn.backend_pid + + sleep(0.25) + nconn.execute("notify foo, '1'") + sleep(0.25) + nconn.execute("notify foo, '2'") + + def receiver(): + conn.set_autocommit(True) + cur = conn.cursor() + cur.execute("listen foo") + gen = conn.notifies() + for n in gen: + ns.append((n, time())) + if len(ns) >= 2: + gen.close() + + ns: list[tuple[Notify, float]] = [] + t0 = time() + workers = [spawn(notifier), spawn(receiver)] + gather(*workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +@pytest.mark.slow +@pytest.mark.timing +def test_no_notify_timeout(conn): + conn.set_autocommit(True) + t0 = time() + for n in conn.notifies(timeout=0.5): + assert False + dt = time() - t0 + assert 0.5 <= dt < 0.75 + + +@pytest.mark.slow +@pytest.mark.timing +def test_notify_timeout(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + sleep(0.25) + nconn.execute("notify foo, '1'") + + worker = spawn(notifier) + try: + times = [time()] + for n in conn.notifies(timeout=0.5): + times.append(time()) + times.append(time()) + finally: + gather(worker) + + assert len(times) == 3 + assert times[1] - times[0] == pytest.approx(0.25, 0.1) + assert times[2] - times[1] == pytest.approx(0.25, 0.1) + + +@pytest.mark.slow +def test_notify_timeout_0(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + ns = list(conn.notifies(timeout=0)) + assert not ns + + with conn_cls.connect(dsn, autocommit=True) as nconn: + nconn.execute("notify foo, '1'") + sleep(0.1) + + ns = list(conn.notifies(timeout=0)) + assert len(ns) == 1 + + +@pytest.mark.slow +def test_stop_after(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + nconn.execute("notify foo, '1'") + sleep(0.1) + nconn.execute("notify foo, '2'") + sleep(0.1) + nconn.execute("notify foo, '3'") + + worker = spawn(notifier) + try: + ns = list(conn.notifies(timeout=1.0, stop_after=2)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + gather(worker) + + ns = list(conn.notifies(timeout=0.0)) + assert len(ns) == 1 + assert ns[0].payload == "3" + + +def test_stop_after_batch(conn_cls, conn, dsn): + conn.set_autocommit(True) + conn.execute("listen foo") + + def notifier(): + with conn_cls.connect(dsn, autocommit=True) as nconn: + with nconn.transaction(): + nconn.execute("notify foo, '1'") + nconn.execute("notify foo, '2'") + + worker = spawn(notifier) + try: + ns = list(conn.notifies(timeout=1.0, stop_after=1)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + gather(worker) diff --git a/tests/test_notify_async.py b/tests/test_notify_async.py new file mode 100644 index 000000000..f4f0901d6 --- /dev/null +++ b/tests/test_notify_async.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from time import time + +import pytest +from psycopg import Notify + +from .acompat import alist, asleep, gather, spawn + +pytestmark = pytest.mark.crdb_skip("notify") + + +async def test_notify_handlers(aconn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + aconn.add_notify_handler(cb1) + aconn.add_notify_handler(lambda n: nots2.append(n)) + + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + await aconn.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == aconn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + aconn.remove_notify_handler(cb1) + await aconn.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert isinstance(n, Notify) + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == aconn.pgconn.backend_pid + assert hash(n) + + with pytest.raises(ValueError): + aconn.remove_notify_handler(cb1) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_notify(aconn_cls, aconn, dsn): + npid = None + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + nonlocal npid + npid = nconn.pgconn.backend_pid + + await asleep(0.25) + await nconn.execute("notify foo, '1'") + await asleep(0.25) + await nconn.execute("notify foo, '2'") + + async def receiver(): + await aconn.set_autocommit(True) + cur = aconn.cursor() + await cur.execute("listen foo") + gen = aconn.notifies() + async for n in gen: + ns.append((n, time())) + if len(ns) >= 2: + await gen.aclose() + + ns: list[tuple[Notify, float]] = [] + t0 = time() + workers = [spawn(notifier), spawn(receiver)] + await gather(*workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +@pytest.mark.slow +@pytest.mark.timing +async def test_no_notify_timeout(aconn): + await aconn.set_autocommit(True) + t0 = time() + async for n in aconn.notifies(timeout=0.5): + assert False + dt = time() - t0 + assert 0.5 <= dt < 0.75 + + +@pytest.mark.slow +@pytest.mark.timing +async def test_notify_timeout(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + await asleep(0.25) + await nconn.execute("notify foo, '1'") + + worker = spawn(notifier) + try: + times = [time()] + async for n in aconn.notifies(timeout=0.5): + times.append(time()) + times.append(time()) + finally: + await gather(worker) + + assert len(times) == 3 + assert times[1] - times[0] == pytest.approx(0.25, 0.1) + assert times[2] - times[1] == pytest.approx(0.25, 0.1) + + +@pytest.mark.slow +async def test_notify_timeout_0(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + ns = await alist(aconn.notifies(timeout=0)) + assert not ns + + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + await nconn.execute("notify foo, '1'") + await asleep(0.1) + + ns = await alist(aconn.notifies(timeout=0)) + assert len(ns) == 1 + + +@pytest.mark.slow +async def test_stop_after(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + await nconn.execute("notify foo, '1'") + await asleep(0.1) + await nconn.execute("notify foo, '2'") + await asleep(0.1) + await nconn.execute("notify foo, '3'") + + worker = spawn(notifier) + try: + ns = await alist(aconn.notifies(timeout=1.0, stop_after=2)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + await gather(worker) + + ns = await alist(aconn.notifies(timeout=0.0)) + assert len(ns) == 1 + assert ns[0].payload == "3" + + +async def test_stop_after_batch(aconn_cls, aconn, dsn): + await aconn.set_autocommit(True) + await aconn.execute("listen foo") + + async def notifier(): + async with await aconn_cls.connect(dsn, autocommit=True) as nconn: + async with nconn.transaction(): + await nconn.execute("notify foo, '1'") + await nconn.execute("notify foo, '2'") + + worker = spawn(notifier) + try: + ns = await alist(aconn.notifies(timeout=1.0, stop_after=1)) + assert len(ns) == 2 + assert ns[0].payload == "1" + assert ns[1].payload == "2" + finally: + await gather(worker) diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index d264c78bd..c4c053a1a 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -48,6 +48,7 @@ ALL_INPUTS = """ tests/test_cursor_common_async.py tests/test_cursor_raw_async.py tests/test_cursor_server_async.py + tests/test_notify_async.py tests/test_pipeline_async.py tests/test_prepared_async.py tests/test_tpc_async.py