]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add `timeout` and `stop_after` parameters to Connection.notifies
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Oct 2023 21:51:17 +0000 (23:51 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 29 Jan 2024 02:25:32 +0000 (02:25 +0000)
Close #340

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_concurrency_async.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_notify.py [new file with mode: 0644]
tests/test_notify_async.py [new file with mode: 0644]
tools/async_to_sync.py

index 2c503a1267de44779be26c20f1e10ce2d206e582..cb0244aa504d5c7dc617dc16a7bcfb057a6218ba 100644 (file)
@@ -10,6 +10,7 @@ Psycopg connection object (sync version)
 from __future__ import annotations
 
 import logging
+from time import monotonic
 from types import TracebackType
 from typing import Any, Generator, Iterator, List, Optional
 from typing import Type, Union, cast, overload, TYPE_CHECKING
@@ -39,6 +40,8 @@ from threading import Lock
 if TYPE_CHECKING:
     from .pq.abc import PGconn
 
+_WAIT_INTERVAL = 0.1
+
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
@@ -277,20 +280,56 @@ class Connection(BaseConnection[Row]):
             with tx:
                 yield tx
 
-    def notifies(self) -> Generator[Notify, None, None]:
+    def notifies(
+        self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None
+    ) -> Generator[Notify, None, None]:
         """
         Yield `Notify` objects as soon as they are received from the database.
+
+        :param timeout: maximum amount of time to wait for notifications.
+            `!None` means no timeout.
+        :param stop_after: stop after receiving this number of notifications.
+            You might actually receive more than this number if more than one
+            notifications arrives in the same packet.
         """
+        # Allow interrupting the wait with a signal by reducing a long timeout
+        # into shorter interval.
+        if timeout is not None:
+            deadline = monotonic() + timeout
+            timeout = min(timeout, _WAIT_INTERVAL)
+        else:
+            deadline = None
+            timeout = _WAIT_INTERVAL
+
+        nreceived = 0
+
         while True:
-            with self.lock:
-                try:
-                    ns = self.wait(notifies(self.pgconn))
-                except e._NO_TRACEBACK as ex:
-                    raise ex.with_traceback(None)
-            enc = pgconn_encoding(self.pgconn)
+            # Collect notifications. Also get the connection encoding if any
+            # notification is received to makes sure that they are consistent.
+            try:
+                with self.lock:
+                    ns = self.wait(notifies(self.pgconn), timeout=timeout)
+                    if ns:
+                        enc = pgconn_encoding(self.pgconn)
+            except e._NO_TRACEBACK as ex:
+                raise ex.with_traceback(None)
+
+            # Emit the notifications received.
             for pgn in ns:
                 n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
+                nreceived += 1
+
+            # Stop if we have received enough notifications.
+            if stop_after is not None and nreceived >= stop_after:
+                break
+
+            # Check the deadline after the loop to ensure that timeout=0
+            # polls at least once.
+            if deadline:
+                timeout = min(_WAIT_INTERVAL, deadline - monotonic())
+                if timeout < 0.0:
+                    break
 
     @contextmanager
     def pipeline(self) -> Iterator[Pipeline]:
@@ -312,7 +351,7 @@ class Connection(BaseConnection[Row]):
                     assert pipeline is self._pipeline
                     self._pipeline = None
 
-    def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+    def wait(self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL) -> RV:
         """
         Consume a generator operating on the connection.
 
index 3a57df3750635a48d2143fb5538c2a8c4fd58b57..d810d45b29850dcd8cead444a863f25e85a51adc 100644 (file)
@@ -7,6 +7,7 @@ Psycopg connection object (async version)
 from __future__ import annotations
 
 import logging
+from time import monotonic
 from types import TracebackType
 from typing import Any, AsyncGenerator, AsyncIterator, List, Optional
 from typing import Type, Union, cast, overload, TYPE_CHECKING
@@ -41,6 +42,8 @@ else:
 if TYPE_CHECKING:
     from .pq.abc import PGconn
 
+_WAIT_INTERVAL = 0.1
+
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
@@ -293,20 +296,56 @@ class AsyncConnection(BaseConnection[Row]):
             async with tx:
                 yield tx
 
-    async def notifies(self) -> AsyncGenerator[Notify, None]:
+    async def notifies(
+        self, *, timeout: Optional[float] = None, stop_after: Optional[int] = None
+    ) -> AsyncGenerator[Notify, None]:
         """
         Yield `Notify` objects as soon as they are received from the database.
+
+        :param timeout: maximum amount of time to wait for notifications.
+            `!None` means no timeout.
+        :param stop_after: stop after receiving this number of notifications.
+            You might actually receive more than this number if more than one
+            notifications arrives in the same packet.
         """
+        # Allow interrupting the wait with a signal by reducing a long timeout
+        # into shorter interval.
+        if timeout is not None:
+            deadline = monotonic() + timeout
+            timeout = min(timeout, _WAIT_INTERVAL)
+        else:
+            deadline = None
+            timeout = _WAIT_INTERVAL
+
+        nreceived = 0
+
         while True:
-            async with self.lock:
-                try:
-                    ns = await self.wait(notifies(self.pgconn))
-                except e._NO_TRACEBACK as ex:
-                    raise ex.with_traceback(None)
-            enc = pgconn_encoding(self.pgconn)
+            # Collect notifications. Also get the connection encoding if any
+            # notification is received to makes sure that they are consistent.
+            try:
+                async with self.lock:
+                    ns = await self.wait(notifies(self.pgconn), timeout=timeout)
+                    if ns:
+                        enc = pgconn_encoding(self.pgconn)
+            except e._NO_TRACEBACK as ex:
+                raise ex.with_traceback(None)
+
+            # Emit the notifications received.
             for pgn in ns:
                 n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
+                nreceived += 1
+
+            # Stop if we have received enough notifications.
+            if stop_after is not None and nreceived >= stop_after:
+                break
+
+            # Check the deadline after the loop to ensure that timeout=0
+            # polls at least once.
+            if deadline:
+                timeout = min(_WAIT_INTERVAL, deadline - monotonic())
+                if timeout < 0.0:
+                    break
 
     @asynccontextmanager
     async def pipeline(self) -> AsyncIterator[AsyncPipeline]:
@@ -328,7 +367,9 @@ class AsyncConnection(BaseConnection[Row]):
                     assert pipeline is self._pipeline
                     self._pipeline = None
 
-    async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
+    async def wait(
+        self, gen: PQGen[RV], timeout: Optional[float] = _WAIT_INTERVAL
+    ) -> RV:
         """
         Consume a generator operating on the connection.
 
index 150d7747793f2b82d4355ae9f28c450de92c38a6..a14fb93e6569bbc5a53dc819e897be2a539d287d 100644 (file)
@@ -6,7 +6,7 @@ import threading
 import subprocess as sp
 from asyncio import create_task
 from asyncio.queues import Queue
-from typing import List, Tuple
+from typing import List
 
 import pytest
 
@@ -58,50 +58,6 @@ async def test_concurrent_execution(aconn_cls, dsn):
     assert time.time() - t0 < 0.8, "something broken in concurrency"
 
 
-@pytest.mark.slow
-@pytest.mark.timing
-@pytest.mark.crdb_skip("notify")
-async def test_notifies(aconn_cls, aconn, dsn):
-    nconn = await aconn_cls.connect(dsn, autocommit=True)
-    npid = nconn.pgconn.backend_pid
-
-    async def notifier():
-        cur = nconn.cursor()
-        await asyncio.sleep(0.25)
-        await cur.execute("notify foo, '1'")
-        await asyncio.sleep(0.25)
-        await cur.execute("notify foo, '2'")
-        await nconn.close()
-
-    async def receiver():
-        await aconn.set_autocommit(True)
-        cur = aconn.cursor()
-        await cur.execute("listen foo")
-        gen = aconn.notifies()
-        async for n in gen:
-            ns.append((n, time.time()))
-            if len(ns) >= 2:
-                await gen.aclose()
-
-    ns: List[Tuple[psycopg.Notify, float]] = []
-    t0 = time.time()
-    workers = [notifier(), receiver()]
-    await asyncio.gather(*workers)
-    assert len(ns) == 2
-
-    n, t1 = ns[0]
-    assert n.pid == npid
-    assert n.channel == "foo"
-    assert n.payload == "1"
-    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
-
-    n, t1 = ns[1]
-    assert n.pid == npid
-    assert n.channel == "foo"
-    assert n.payload == "2"
-    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
-
-
 async def canceller(aconn, errors):
     try:
         await asyncio.sleep(0.5)
index 1d9217d2b9893d3a4f3fb04a3753afa5befd41f5..f17f8d2a06f85f3097dad4e685e740b52b9ba732 100644 (file)
@@ -9,7 +9,7 @@ import weakref
 from typing import Any, List
 
 import psycopg
-from psycopg import Notify, pq, errors as e
+from psycopg import pq, errors as e
 from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
@@ -528,47 +528,6 @@ def test_notice_handlers(conn, caplog):
         conn.remove_notice_handler(cb1)
 
 
-@pytest.mark.crdb_skip("notify")
-def test_notify_handlers(conn):
-    nots1 = []
-    nots2 = []
-
-    def cb1(n):
-        nots1.append(n)
-
-    conn.add_notify_handler(cb1)
-    conn.add_notify_handler(lambda n: nots2.append(n))
-
-    conn.set_autocommit(True)
-    cur = conn.cursor()
-    cur.execute("listen foo")
-    cur.execute("notify foo, 'n1'")
-
-    assert len(nots1) == 1
-    n = nots1[0]
-    assert n.channel == "foo"
-    assert n.payload == "n1"
-    assert n.pid == conn.pgconn.backend_pid
-
-    assert len(nots2) == 1
-    assert nots2[0] == nots1[0]
-
-    conn.remove_notify_handler(cb1)
-    cur.execute("notify foo, 'n2'")
-
-    assert len(nots1) == 1
-    assert len(nots2) == 2
-    n = nots2[1]
-    assert isinstance(n, Notify)
-    assert n.channel == "foo"
-    assert n.payload == "n2"
-    assert n.pid == conn.pgconn.backend_pid
-    assert hash(n)
-
-    with pytest.raises(ValueError):
-        conn.remove_notify_handler(cb1)
-
-
 def test_execute(conn):
     cur = conn.execute("select %s, %s", [10, 20])
     assert cur.fetchone() == (10, 20)
index a98f0f80b2de6cf08c1f2a1211740d977241d2c1..d7aa7ca8b41898e927a5e1c9dbe15fc30420d75c 100644 (file)
@@ -6,7 +6,7 @@ import weakref
 from typing import Any, List
 
 import psycopg
-from psycopg import Notify, pq, errors as e
+from psycopg import pq, errors as e
 from psycopg.rows import tuple_row
 from psycopg.conninfo import conninfo_to_dict, timeout_from_conninfo
 
@@ -526,47 +526,6 @@ async def test_notice_handlers(aconn, caplog):
         aconn.remove_notice_handler(cb1)
 
 
-@pytest.mark.crdb_skip("notify")
-async def test_notify_handlers(aconn):
-    nots1 = []
-    nots2 = []
-
-    def cb1(n):
-        nots1.append(n)
-
-    aconn.add_notify_handler(cb1)
-    aconn.add_notify_handler(lambda n: nots2.append(n))
-
-    await aconn.set_autocommit(True)
-    cur = aconn.cursor()
-    await cur.execute("listen foo")
-    await cur.execute("notify foo, 'n1'")
-
-    assert len(nots1) == 1
-    n = nots1[0]
-    assert n.channel == "foo"
-    assert n.payload == "n1"
-    assert n.pid == aconn.pgconn.backend_pid
-
-    assert len(nots2) == 1
-    assert nots2[0] == nots1[0]
-
-    aconn.remove_notify_handler(cb1)
-    await cur.execute("notify foo, 'n2'")
-
-    assert len(nots1) == 1
-    assert len(nots2) == 2
-    n = nots2[1]
-    assert isinstance(n, Notify)
-    assert n.channel == "foo"
-    assert n.payload == "n2"
-    assert n.pid == aconn.pgconn.backend_pid
-    assert hash(n)
-
-    with pytest.raises(ValueError):
-        aconn.remove_notify_handler(cb1)
-
-
 async def test_execute(aconn):
     cur = await aconn.execute("select %s, %s", [10, 20])
     assert await cur.fetchone() == (10, 20)
diff --git a/tests/test_notify.py b/tests/test_notify.py
new file mode 100644 (file)
index 0000000..c67b331
--- /dev/null
@@ -0,0 +1,195 @@
+# WARNING: this file is auto-generated by 'async_to_sync.py'
+# from the original file 'test_notify_async.py'
+# DO NOT CHANGE! Change the original file instead.
+from __future__ import annotations
+
+from time import time
+
+import pytest
+from psycopg import Notify
+
+from .acompat import sleep, gather, spawn
+
+pytestmark = pytest.mark.crdb_skip("notify")
+
+
+def test_notify_handlers(conn):
+    nots1 = []
+    nots2 = []
+
+    def cb1(n):
+        nots1.append(n)
+
+    conn.add_notify_handler(cb1)
+    conn.add_notify_handler(lambda n: nots2.append(n))
+
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+    conn.execute("notify foo, 'n1'")
+
+    assert len(nots1) == 1
+    n = nots1[0]
+    assert n.channel == "foo"
+    assert n.payload == "n1"
+    assert n.pid == conn.pgconn.backend_pid
+
+    assert len(nots2) == 1
+    assert nots2[0] == nots1[0]
+
+    conn.remove_notify_handler(cb1)
+    conn.execute("notify foo, 'n2'")
+
+    assert len(nots1) == 1
+    assert len(nots2) == 2
+    n = nots2[1]
+    assert isinstance(n, Notify)
+    assert n.channel == "foo"
+    assert n.payload == "n2"
+    assert n.pid == conn.pgconn.backend_pid
+    assert hash(n)
+
+    with pytest.raises(ValueError):
+        conn.remove_notify_handler(cb1)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_notify(conn_cls, conn, dsn):
+    npid = None
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            nonlocal npid
+            npid = nconn.pgconn.backend_pid
+
+            sleep(0.25)
+            nconn.execute("notify foo, '1'")
+            sleep(0.25)
+            nconn.execute("notify foo, '2'")
+
+    def receiver():
+        conn.set_autocommit(True)
+        cur = conn.cursor()
+        cur.execute("listen foo")
+        gen = conn.notifies()
+        for n in gen:
+            ns.append((n, time()))
+            if len(ns) >= 2:
+                gen.close()
+
+    ns: list[tuple[Notify, float]] = []
+    t0 = time()
+    workers = [spawn(notifier), spawn(receiver)]
+    gather(*workers)
+    assert len(ns) == 2
+
+    n, t1 = ns[0]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "1"
+    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+    n, t1 = ns[1]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "2"
+    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_no_notify_timeout(conn):
+    conn.set_autocommit(True)
+    t0 = time()
+    for n in conn.notifies(timeout=0.5):
+        assert False
+    dt = time() - t0
+    assert 0.5 <= dt < 0.75
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+def test_notify_timeout(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            sleep(0.25)
+            nconn.execute("notify foo, '1'")
+
+    worker = spawn(notifier)
+    try:
+        times = [time()]
+        for n in conn.notifies(timeout=0.5):
+            times.append(time())
+        times.append(time())
+    finally:
+        gather(worker)
+
+    assert len(times) == 3
+    assert times[1] - times[0] == pytest.approx(0.25, 0.1)
+    assert times[2] - times[1] == pytest.approx(0.25, 0.1)
+
+
+@pytest.mark.slow
+def test_notify_timeout_0(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    ns = list(conn.notifies(timeout=0))
+    assert not ns
+
+    with conn_cls.connect(dsn, autocommit=True) as nconn:
+        nconn.execute("notify foo, '1'")
+        sleep(0.1)
+
+    ns = list(conn.notifies(timeout=0))
+    assert len(ns) == 1
+
+
+@pytest.mark.slow
+def test_stop_after(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            nconn.execute("notify foo, '1'")
+            sleep(0.1)
+            nconn.execute("notify foo, '2'")
+            sleep(0.1)
+            nconn.execute("notify foo, '3'")
+
+    worker = spawn(notifier)
+    try:
+        ns = list(conn.notifies(timeout=1.0, stop_after=2))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        gather(worker)
+
+    ns = list(conn.notifies(timeout=0.0))
+    assert len(ns) == 1
+    assert ns[0].payload == "3"
+
+
+def test_stop_after_batch(conn_cls, conn, dsn):
+    conn.set_autocommit(True)
+    conn.execute("listen foo")
+
+    def notifier():
+        with conn_cls.connect(dsn, autocommit=True) as nconn:
+            with nconn.transaction():
+                nconn.execute("notify foo, '1'")
+                nconn.execute("notify foo, '2'")
+
+    worker = spawn(notifier)
+    try:
+        ns = list(conn.notifies(timeout=1.0, stop_after=1))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        gather(worker)
diff --git a/tests/test_notify_async.py b/tests/test_notify_async.py
new file mode 100644 (file)
index 0000000..f4f0901
--- /dev/null
@@ -0,0 +1,192 @@
+from __future__ import annotations
+
+from time import time
+
+import pytest
+from psycopg import Notify
+
+from .acompat import alist, asleep, gather, spawn
+
+pytestmark = pytest.mark.crdb_skip("notify")
+
+
+async def test_notify_handlers(aconn):
+    nots1 = []
+    nots2 = []
+
+    def cb1(n):
+        nots1.append(n)
+
+    aconn.add_notify_handler(cb1)
+    aconn.add_notify_handler(lambda n: nots2.append(n))
+
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+    await aconn.execute("notify foo, 'n1'")
+
+    assert len(nots1) == 1
+    n = nots1[0]
+    assert n.channel == "foo"
+    assert n.payload == "n1"
+    assert n.pid == aconn.pgconn.backend_pid
+
+    assert len(nots2) == 1
+    assert nots2[0] == nots1[0]
+
+    aconn.remove_notify_handler(cb1)
+    await aconn.execute("notify foo, 'n2'")
+
+    assert len(nots1) == 1
+    assert len(nots2) == 2
+    n = nots2[1]
+    assert isinstance(n, Notify)
+    assert n.channel == "foo"
+    assert n.payload == "n2"
+    assert n.pid == aconn.pgconn.backend_pid
+    assert hash(n)
+
+    with pytest.raises(ValueError):
+        aconn.remove_notify_handler(cb1)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_notify(aconn_cls, aconn, dsn):
+    npid = None
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            nonlocal npid
+            npid = nconn.pgconn.backend_pid
+
+            await asleep(0.25)
+            await nconn.execute("notify foo, '1'")
+            await asleep(0.25)
+            await nconn.execute("notify foo, '2'")
+
+    async def receiver():
+        await aconn.set_autocommit(True)
+        cur = aconn.cursor()
+        await cur.execute("listen foo")
+        gen = aconn.notifies()
+        async for n in gen:
+            ns.append((n, time()))
+            if len(ns) >= 2:
+                await gen.aclose()
+
+    ns: list[tuple[Notify, float]] = []
+    t0 = time()
+    workers = [spawn(notifier), spawn(receiver)]
+    await gather(*workers)
+    assert len(ns) == 2
+
+    n, t1 = ns[0]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "1"
+    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+    n, t1 = ns[1]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "2"
+    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_no_notify_timeout(aconn):
+    await aconn.set_autocommit(True)
+    t0 = time()
+    async for n in aconn.notifies(timeout=0.5):
+        assert False
+    dt = time() - t0
+    assert 0.5 <= dt < 0.75
+
+
+@pytest.mark.slow
+@pytest.mark.timing
+async def test_notify_timeout(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            await asleep(0.25)
+            await nconn.execute("notify foo, '1'")
+
+    worker = spawn(notifier)
+    try:
+        times = [time()]
+        async for n in aconn.notifies(timeout=0.5):
+            times.append(time())
+        times.append(time())
+    finally:
+        await gather(worker)
+
+    assert len(times) == 3
+    assert times[1] - times[0] == pytest.approx(0.25, 0.1)
+    assert times[2] - times[1] == pytest.approx(0.25, 0.1)
+
+
+@pytest.mark.slow
+async def test_notify_timeout_0(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    ns = await alist(aconn.notifies(timeout=0))
+    assert not ns
+
+    async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+        await nconn.execute("notify foo, '1'")
+        await asleep(0.1)
+
+    ns = await alist(aconn.notifies(timeout=0))
+    assert len(ns) == 1
+
+
+@pytest.mark.slow
+async def test_stop_after(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            await nconn.execute("notify foo, '1'")
+            await asleep(0.1)
+            await nconn.execute("notify foo, '2'")
+            await asleep(0.1)
+            await nconn.execute("notify foo, '3'")
+
+    worker = spawn(notifier)
+    try:
+        ns = await alist(aconn.notifies(timeout=1.0, stop_after=2))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        await gather(worker)
+
+    ns = await alist(aconn.notifies(timeout=0.0))
+    assert len(ns) == 1
+    assert ns[0].payload == "3"
+
+
+async def test_stop_after_batch(aconn_cls, aconn, dsn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("listen foo")
+
+    async def notifier():
+        async with await aconn_cls.connect(dsn, autocommit=True) as nconn:
+            async with nconn.transaction():
+                await nconn.execute("notify foo, '1'")
+                await nconn.execute("notify foo, '2'")
+
+    worker = spawn(notifier)
+    try:
+        ns = await alist(aconn.notifies(timeout=1.0, stop_after=1))
+        assert len(ns) == 2
+        assert ns[0].payload == "1"
+        assert ns[1].payload == "2"
+    finally:
+        await gather(worker)
index d264c78bd20e5339eaa83516ff9f205f5ff1cb97..c4c053a1ab78b91de4f0f3ac8b6d50439a55692d 100755 (executable)
@@ -48,6 +48,7 @@ ALL_INPUTS = """
     tests/test_cursor_common_async.py
     tests/test_cursor_raw_async.py
     tests/test_cursor_server_async.py
+    tests/test_notify_async.py
     tests/test_pipeline_async.py
     tests/test_prepared_async.py
     tests/test_tpc_async.py