From: Daniele Varrazzo Date: Mon, 25 May 2020 06:13:44 +0000 (+1200) Subject: Added notification handling in connections X-Git-Tag: 3.0.dev0~489 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=dc378b8f6d594edeb4eaf43354d216497a7be16d;p=thirdparty%2Fpsycopg.git Added notification handling in connections Added both a callback systen and an explicit generator. I'll share the design on the ML and ask for comments. --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index ad319fd43..917870111 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -8,7 +8,8 @@ import codecs import logging import asyncio import threading -from typing import Any, Callable, List, Optional, Type, cast +from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple +from typing import Optional, Type, cast from weakref import ref, ReferenceType from functools import partial @@ -19,11 +20,11 @@ from . import proto from .pq import TransactionStatus, ExecStatus from .conninfo import make_conninfo from .waiting import wait, wait_async +from .generators import notifies logger = logging.getLogger(__name__) package_logger = logging.getLogger("psycopg3") - connect: Callable[[str], proto.PQGen[pq.proto.PGconn]] execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]] @@ -39,7 +40,15 @@ else: connect = generators.connect execute = generators.execute + +class Notify(NamedTuple): + channel: str + payload: str + pid: int + + NoticeHandler = Callable[[e.Diagnostic], None] +NotifyHandler = Callable[[Notify], None] class BaseConnection: @@ -73,12 +82,14 @@ class BaseConnection: self.dumpers: proto.DumpersMap = {} self.loaders: proto.LoadersMap = {} self._notice_handlers: List[NoticeHandler] = [] + self._notify_handlers: List[NotifyHandler] = [] # name of the postgres encoding (in bytes) self._pgenc = b"" wself = ref(self) pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) + pgconn.notify_handler = partial(BaseConnection._notify_handler, wself) @property def closed(self) -> bool: @@ -161,6 +172,25 @@ class BaseConnection: "error processing notice callback '%s': %s", cb, ex ) + def add_notify_handler(self, callback: NotifyHandler) -> None: + self._notify_handlers.append(callback) + + def remove_notify_handler(self, callback: NotifyHandler) -> None: + self._notify_handlers.remove(callback) + + @staticmethod + def _notify_handler( + wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify + ) -> None: + self = wself() + if self is None or not self._notify_handlers: + return + + decode = self.codec.decode + n = Notify(decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid) + for cb in self._notify_handlers: + cb(n) + class Connection(BaseConnection): """ @@ -254,6 +284,19 @@ class Connection(BaseConnection): if result.status != ExecStatus.TUPLES_OK: raise e.error_from_result(result) + def notifies(self) -> Generator[Optional[Notify], bool, None]: + decode = self.codec.decode + while 1: + with self.lock: + ns = self.wait(notifies(self.pgconn)) + for pgn in ns: + n = Notify( + decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid + ) + if (yield n): + yield None # for the send who stopped us + return + class AsyncConnection(BaseConnection): """ @@ -345,3 +388,16 @@ class AsyncConnection(BaseConnection): (result,) = await self.wait(gen) if result.status != ExecStatus.TUPLES_OK: raise e.error_from_result(result) + + async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]: + decode = self.codec.decode + while 1: + async with self.lock: + ns = await self.wait(notifies(self.pgconn)) + for pgn in ns: + n = Notify( + decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid + ) + if (yield n): + yield None + return diff --git a/psycopg3/generators.py b/psycopg3/generators.py index cf7561490..3fd82e6be 100644 --- a/psycopg3/generators.py +++ b/psycopg3/generators.py @@ -92,6 +92,8 @@ def send(pgconn: pq.proto.PGconn) -> PQGen[None]: ready = yield pgconn.socket, Wait.RW if ready & Ready.R: + # This call may read notifies: they will be saved in the + # PGconn buffer and passed to Python later, in `fetch()`. pgconn.consume_input() continue @@ -113,6 +115,15 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: if pgconn.is_busy(): yield pgconn.socket, Wait.R continue + + # Consume notifies + while 1: + n = pgconn.notifies() + if n is None: + break + if pgconn.notify_handler is not None: + pgconn.notify_handler(n) + res = pgconn.get_result() if res is None: break @@ -123,3 +134,18 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]: break return results + + +def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]: + yield pgconn.socket, Wait.R + pgconn.consume_input() + + ns = [] + while 1: + n = pgconn.notifies() + if n is not None: + ns.append(n) + else: + break + + return ns diff --git a/psycopg3/generators.pyx b/psycopg3/generators.pyx index 4eb9a5684..b11f5be2c 100644 --- a/psycopg3/generators.pyx +++ b/psycopg3/generators.pyx @@ -69,6 +69,7 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]: results: List[pq.proto.PGresult] = [] cdef libpq.PGconn *pgconn_ptr = pgconn.pgconn_ptr cdef int status + cdef libpq.PGnotify *notify # Sending the query while 1: @@ -77,6 +78,8 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]: status = yield libpq.PQsocket(pgconn_ptr), WAIT_RW if status & READY_R: + # This call may read notifies which will be saved in the + # PGconn buffer and passed to Python later. if 1 != libpq.PQconsumeInput(pgconn_ptr): raise pq.PQerror( f"consuming input failed: {pq.error_message(pgconn)}") @@ -93,6 +96,20 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]: yield wr continue + # Consume notifies + if pgconn.notify_handler: + while 1: + pynotify = pgconn.notifies() + if pynotify is None: + break + pgconn.notify_handler(pynotify) + else: + while 1: + notify = libpq.PQnotifies(pgconn_ptr) + if notify is NULL: + break + libpq.PQfreemem(notify) + res = libpq.PQgetResult(pgconn_ptr) if res is NULL: break diff --git a/psycopg3/pq/__init__.py b/psycopg3/pq/__init__.py index 16463e9c2..e323039f1 100644 --- a/psycopg3/pq/__init__.py +++ b/psycopg3/pq/__init__.py @@ -23,7 +23,7 @@ from .enums import ( Format, ) from .encodings import py_codecs -from .misc import error_message, ConninfoOption, PQerror +from .misc import error_message, ConninfoOption, PQerror, PGnotify from . import proto logger = logging.getLogger(__name__) @@ -100,6 +100,7 @@ __all__ = ( "DiagnosticField", "Format", "PGconn", + "PGnotify", "Conninfo", "PQerror", "error_message", diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index e86d998ad..25cd3d43f 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -62,6 +62,7 @@ class PGconn: __slots__ = ( "pgconn_ptr", "notice_handler", + "notify_handler", "_notice_receiver", "_procpid", "__weakref__", @@ -69,7 +70,10 @@ class PGconn: def __init__(self, pgconn_ptr: impl.PGconn_struct): self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr - self.notice_handler: Optional[Callable[..., None]] = None + self.notice_handler: Optional[ + Callable[["pq.proto.PGresult"], None] + ] = None + self.notify_handler: Optional[Callable[[PGnotify], None]] = None self._notice_receiver = impl.PQnoticeReceiver( # type: ignore partial(notice_receiver, wconn=ref(self)) @@ -480,7 +484,7 @@ class PGconn: raise PQerror(f"flushing failed: {error_message(self)}") return rv - def notifies(self) -> Optional["PGnotify"]: + def notifies(self) -> Optional[PGnotify]: ptr = impl.PQnotifies(self.pgconn_ptr) if ptr: c = ptr.contents diff --git a/psycopg3/pq/pq_cython.pxd b/psycopg3/pq/pq_cython.pxd index bf61c01d2..8c41348db 100644 --- a/psycopg3/pq/pq_cython.pxd +++ b/psycopg3/pq/pq_cython.pxd @@ -9,6 +9,7 @@ cdef class PGconn: cdef impl.PGconn* pgconn_ptr cdef object __weakref__ cdef public object notice_handler + cdef public object notify_handler cdef pid_t _procpid @staticmethod diff --git a/psycopg3/pq/proto.py b/psycopg3/pq/proto.py index 686f33eca..2e684c5c2 100644 --- a/psycopg3/pq/proto.py +++ b/psycopg3/pq/proto.py @@ -18,12 +18,13 @@ from .enums import ( ) if TYPE_CHECKING: - from .misc import ConninfoOption # noqa + from .misc import PGnotify, ConninfoOption # noqa class PGconn(Protocol): notice_handler: Optional[Callable[["PGresult"], None]] + notify_handler: Optional[Callable[["PGnotify"], None]] @classmethod def connect(cls, conninfo: bytes) -> "PGconn": @@ -217,6 +218,9 @@ class PGconn(Protocol): def flush(self) -> int: ... + def notifies(self) -> Optional["PGnotify"]: + ... + def make_empty_result(self, exec_status: ExecStatus) -> "PGresult": ... diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 5aa126ca7..6f923737f 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -105,3 +105,40 @@ t.join() assert out == "", out.strip().splitlines()[-1] finally: shutil.rmtree(dir, ignore_errors=True) + + +@pytest.mark.slow +def test_notifies(conn, dsn): + nconn = psycopg3.connect(dsn) + npid = nconn.pgconn.backend_pid + + def notifier(): + time.sleep(0.25) + nconn.pgconn.exec_(b"notify foo, '1'") + time.sleep(0.25) + nconn.pgconn.exec_(b"notify foo, '2'") + nconn.close() + + conn.pgconn.exec_(b"listen foo") + t0 = time.time() + t = threading.Thread(target=notifier) + t.start() + ns = [] + gen = conn.notifies() + for n in gen: + ns.append((n, time.time())) + if len(ns) >= 2: + gen.send(True) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index c91667432..30910c294 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -52,3 +52,42 @@ async def test_concurrent_execution(dsn): t0 = time.time() await asyncio.wait(workers) assert time.time() - t0 < 0.8, "something broken in concurrency" + + +@pytest.mark.slow +async def test_notifies(aconn, dsn): + nconn = await psycopg3.AsyncConnection.connect(dsn) + npid = nconn.pgconn.backend_pid + + async def notifier(): + await asyncio.sleep(0.25) + nconn.pgconn.exec_(b"notify foo, '1'") + await asyncio.sleep(0.25) + nconn.pgconn.exec_(b"notify foo, '2'") + await nconn.close() + + async def receiver(): + aconn.pgconn.exec_(b"listen foo") + gen = aconn.notifies() + async for n in gen: + ns.append((n, time.time())) + if len(ns) >= 2: + gen.send(True) + + ns = [] + t0 = time.time() + workers = [notifier(), receiver()] + await asyncio.wait(workers) + assert len(ns) == 2 + + n, t1 = ns[0] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "1" + assert t1 - t0 == pytest.approx(0.25, abs=0.05) + + n, t1 = ns[1] + assert n.pid == npid + assert n.channel == "foo" + assert n.payload == "2" + assert t1 - t0 == pytest.approx(0.5, abs=0.05) diff --git a/tests/test_connection.py b/tests/test_connection.py index 7bcaa60c0..4bcd95693 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -303,3 +303,41 @@ def test_notice_handlers(conn, caplog): with pytest.raises(ValueError): conn.remove_notice_handler(cb1) + + +def test_notify_handlers(conn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + conn.add_notify_handler(cb1) + conn.add_notify_handler(lambda n: nots2.append(n)) + + conn.autocommit = True + cur = conn.cursor() + cur.execute("listen foo") + cur.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == conn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + conn.remove_notify_handler(cb1) + cur.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == conn.pgconn.backend_pid + + with pytest.raises(ValueError): + conn.remove_notify_handler(cb1) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 585cb7738..6953d5b96 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -314,3 +314,41 @@ async def test_notice_handlers(aconn, caplog): with pytest.raises(ValueError): aconn.remove_notice_handler(cb1) + + +async def test_notify_handlers(aconn): + nots1 = [] + nots2 = [] + + def cb1(n): + nots1.append(n) + + aconn.add_notify_handler(cb1) + aconn.add_notify_handler(lambda n: nots2.append(n)) + + aconn.autocommit = True + cur = aconn.cursor() + await cur.execute("listen foo") + await cur.execute("notify foo, 'n1'") + + assert len(nots1) == 1 + n = nots1[0] + assert n.channel == "foo" + assert n.payload == "n1" + assert n.pid == aconn.pgconn.backend_pid + + assert len(nots2) == 1 + assert nots2[0] == nots1[0] + + aconn.remove_notify_handler(cb1) + await cur.execute("notify foo, 'n2'") + + assert len(nots1) == 1 + assert len(nots2) == 2 + n = nots2[1] + assert n.channel == "foo" + assert n.payload == "n2" + assert n.pid == aconn.pgconn.backend_pid + + with pytest.raises(ValueError): + aconn.remove_notify_handler(cb1)