]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added notification handling in connections
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 25 May 2020 06:13:44 +0000 (18:13 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 25 May 2020 06:17:19 +0000 (18:17 +1200)
Added both a callback systen and an explicit generator. I'll share the
design on the ML and ask for comments.

psycopg3/connection.py
psycopg3/generators.py
psycopg3/generators.pyx
psycopg3/pq/__init__.py
psycopg3/pq/pq_ctypes.py
psycopg3/pq/pq_cython.pxd
psycopg3/pq/proto.py
tests/test_concurrency.py
tests/test_concurrency_async.py
tests/test_connection.py
tests/test_connection_async.py

index ad319fd4334a5ae0e4cb8e7213c3afabdb380ac8..917870111aa759e68309809596ececf8dba7e645 100644 (file)
@@ -8,7 +8,8 @@ import codecs
 import logging
 import asyncio
 import threading
-from typing import Any, Callable, List, Optional, Type, cast
+from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple
+from typing import Optional, Type, cast
 from weakref import ref, ReferenceType
 from functools import partial
 
@@ -19,11 +20,11 @@ from . import proto
 from .pq import TransactionStatus, ExecStatus
 from .conninfo import make_conninfo
 from .waiting import wait, wait_async
+from .generators import notifies
 
 logger = logging.getLogger(__name__)
 package_logger = logging.getLogger("psycopg3")
 
-
 connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
 execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]]
 
@@ -39,7 +40,15 @@ else:
     connect = generators.connect
     execute = generators.execute
 
+
+class Notify(NamedTuple):
+    channel: str
+    payload: str
+    pid: int
+
+
 NoticeHandler = Callable[[e.Diagnostic], None]
+NotifyHandler = Callable[[Notify], None]
 
 
 class BaseConnection:
@@ -73,12 +82,14 @@ class BaseConnection:
         self.dumpers: proto.DumpersMap = {}
         self.loaders: proto.LoadersMap = {}
         self._notice_handlers: List[NoticeHandler] = []
+        self._notify_handlers: List[NotifyHandler] = []
         # name of the postgres encoding (in bytes)
         self._pgenc = b""
 
         wself = ref(self)
 
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
+        pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
 
     @property
     def closed(self) -> bool:
@@ -161,6 +172,25 @@ class BaseConnection:
                     "error processing notice callback '%s': %s", cb, ex
                 )
 
+    def add_notify_handler(self, callback: NotifyHandler) -> None:
+        self._notify_handlers.append(callback)
+
+    def remove_notify_handler(self, callback: NotifyHandler) -> None:
+        self._notify_handlers.remove(callback)
+
+    @staticmethod
+    def _notify_handler(
+        wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify
+    ) -> None:
+        self = wself()
+        if self is None or not self._notify_handlers:
+            return
+
+        decode = self.codec.decode
+        n = Notify(decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid)
+        for cb in self._notify_handlers:
+            cb(n)
+
 
 class Connection(BaseConnection):
     """
@@ -254,6 +284,19 @@ class Connection(BaseConnection):
             if result.status != ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
 
+    def notifies(self) -> Generator[Optional[Notify], bool, None]:
+        decode = self.codec.decode
+        while 1:
+            with self.lock:
+                ns = self.wait(notifies(self.pgconn))
+            for pgn in ns:
+                n = Notify(
+                    decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid
+                )
+                if (yield n):
+                    yield None  # for the send who stopped us
+                    return
+
 
 class AsyncConnection(BaseConnection):
     """
@@ -345,3 +388,16 @@ class AsyncConnection(BaseConnection):
             (result,) = await self.wait(gen)
             if result.status != ExecStatus.TUPLES_OK:
                 raise e.error_from_result(result)
+
+    async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]:
+        decode = self.codec.decode
+        while 1:
+            async with self.lock:
+                ns = await self.wait(notifies(self.pgconn))
+            for pgn in ns:
+                n = Notify(
+                    decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid
+                )
+                if (yield n):
+                    yield None
+                    return
index cf75614904e5d3f9e50e5abdb5d3d41038536117..3fd82e6beda0c0127004ec59e284fcf19bc476c2 100644 (file)
@@ -92,6 +92,8 @@ def send(pgconn: pq.proto.PGconn) -> PQGen[None]:
 
         ready = yield pgconn.socket, Wait.RW
         if ready & Ready.R:
+            # This call may read notifies: they will be saved in the
+            # PGconn buffer and passed to Python later, in `fetch()`.
             pgconn.consume_input()
         continue
 
@@ -113,6 +115,15 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]:
         if pgconn.is_busy():
             yield pgconn.socket, Wait.R
             continue
+
+        # Consume notifies
+        while 1:
+            n = pgconn.notifies()
+            if n is None:
+                break
+            if pgconn.notify_handler is not None:
+                pgconn.notify_handler(n)
+
         res = pgconn.get_result()
         if res is None:
             break
@@ -123,3 +134,18 @@ def fetch(pgconn: pq.proto.PGconn) -> PQGen[List[pq.proto.PGresult]]:
             break
 
     return results
+
+
+def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]:
+    yield pgconn.socket, Wait.R
+    pgconn.consume_input()
+
+    ns = []
+    while 1:
+        n = pgconn.notifies()
+        if n is not None:
+            ns.append(n)
+        else:
+            break
+
+    return ns
index 4eb9a5684649d4f0b48f12c69d63bda2e61fa712..b11f5be2cd398a4c1be29f1659ad61b2f697176c 100644 (file)
@@ -69,6 +69,7 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
     results: List[pq.proto.PGresult] = []
     cdef libpq.PGconn *pgconn_ptr = pgconn.pgconn_ptr
     cdef int status
+    cdef libpq.PGnotify *notify
 
     # Sending the query
     while 1:
@@ -77,6 +78,8 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
 
         status = yield libpq.PQsocket(pgconn_ptr), WAIT_RW
         if status & READY_R:
+            # This call may read notifies which will be saved in the
+            # PGconn buffer and passed to Python later.
             if 1 != libpq.PQconsumeInput(pgconn_ptr):
                 raise pq.PQerror(
                     f"consuming input failed: {pq.error_message(pgconn)}")
@@ -93,6 +96,20 @@ def execute(PGconn pgconn) -> PQGen[List[pq.proto.PGresult]]:
             yield wr
             continue
 
+        # Consume notifies
+        if pgconn.notify_handler:
+            while 1:
+                pynotify = pgconn.notifies()
+                if pynotify is None:
+                    break
+                pgconn.notify_handler(pynotify)
+        else:
+            while 1:
+                notify = libpq.PQnotifies(pgconn_ptr)
+                if notify is NULL:
+                    break
+                libpq.PQfreemem(notify)
+
         res = libpq.PQgetResult(pgconn_ptr)
         if res is NULL:
             break
index 16463e9c216acc5f0f72d997bb85fac5041cd3a6..e323039f1fcdd5a0d0e614ccf4f7a491843ead28 100644 (file)
@@ -23,7 +23,7 @@ from .enums import (
     Format,
 )
 from .encodings import py_codecs
-from .misc import error_message, ConninfoOption, PQerror
+from .misc import error_message, ConninfoOption, PQerror, PGnotify
 from . import proto
 
 logger = logging.getLogger(__name__)
@@ -100,6 +100,7 @@ __all__ = (
     "DiagnosticField",
     "Format",
     "PGconn",
+    "PGnotify",
     "Conninfo",
     "PQerror",
     "error_message",
index e86d998ad0db179ccb12a09aa67ea49a1c7f7d0d..25cd3d43f976f713b496a7a3403326241106b061 100644 (file)
@@ -62,6 +62,7 @@ class PGconn:
     __slots__ = (
         "pgconn_ptr",
         "notice_handler",
+        "notify_handler",
         "_notice_receiver",
         "_procpid",
         "__weakref__",
@@ -69,7 +70,10 @@ class PGconn:
 
     def __init__(self, pgconn_ptr: impl.PGconn_struct):
         self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
-        self.notice_handler: Optional[Callable[..., None]] = None
+        self.notice_handler: Optional[
+            Callable[["pq.proto.PGresult"], None]
+        ] = None
+        self.notify_handler: Optional[Callable[[PGnotify], None]] = None
 
         self._notice_receiver = impl.PQnoticeReceiver(  # type: ignore
             partial(notice_receiver, wconn=ref(self))
@@ -480,7 +484,7 @@ class PGconn:
             raise PQerror(f"flushing failed: {error_message(self)}")
         return rv
 
-    def notifies(self) -> Optional["PGnotify"]:
+    def notifies(self) -> Optional[PGnotify]:
         ptr = impl.PQnotifies(self.pgconn_ptr)
         if ptr:
             c = ptr.contents
index bf61c01d2eecfb299883d87e40af0fb92be4aae0..8c41348db1bb1742635097ea6f1e898743664ada 100644 (file)
@@ -9,6 +9,7 @@ cdef class PGconn:
     cdef impl.PGconn* pgconn_ptr
     cdef object __weakref__
     cdef public object notice_handler
+    cdef public object notify_handler
     cdef pid_t _procpid
 
     @staticmethod
index 686f33eca4b64bcb129dee8fde08fc4302e50dff..2e684c5c2393576e2d556a0cde7ed6305e42eab2 100644 (file)
@@ -18,12 +18,13 @@ from .enums import (
 )
 
 if TYPE_CHECKING:
-    from .misc import ConninfoOption  # noqa
+    from .misc import PGnotify, ConninfoOption  # noqa
 
 
 class PGconn(Protocol):
 
     notice_handler: Optional[Callable[["PGresult"], None]]
+    notify_handler: Optional[Callable[["PGnotify"], None]]
 
     @classmethod
     def connect(cls, conninfo: bytes) -> "PGconn":
@@ -217,6 +218,9 @@ class PGconn(Protocol):
     def flush(self) -> int:
         ...
 
+    def notifies(self) -> Optional["PGnotify"]:
+        ...
+
     def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
         ...
 
index 5aa126ca7ebc75eb14f6c13e8eef1faf3a4cb6cc..6f923737f49df797a9a05e99ad0fa34553d4097b 100644 (file)
@@ -105,3 +105,40 @@ t.join()
         assert out == "", out.strip().splitlines()[-1]
     finally:
         shutil.rmtree(dir, ignore_errors=True)
+
+
+@pytest.mark.slow
+def test_notifies(conn, dsn):
+    nconn = psycopg3.connect(dsn)
+    npid = nconn.pgconn.backend_pid
+
+    def notifier():
+        time.sleep(0.25)
+        nconn.pgconn.exec_(b"notify foo, '1'")
+        time.sleep(0.25)
+        nconn.pgconn.exec_(b"notify foo, '2'")
+        nconn.close()
+
+    conn.pgconn.exec_(b"listen foo")
+    t0 = time.time()
+    t = threading.Thread(target=notifier)
+    t.start()
+    ns = []
+    gen = conn.notifies()
+    for n in gen:
+        ns.append((n, time.time()))
+        if len(ns) >= 2:
+            gen.send(True)
+    assert len(ns) == 2
+
+    n, t1 = ns[0]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "1"
+    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+    n, t1 = ns[1]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "2"
+    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
index c916674329a7064558f4c79676d6103f524cdc4f..30910c29494830544e03c9cb8593636f42266204 100644 (file)
@@ -52,3 +52,42 @@ async def test_concurrent_execution(dsn):
     t0 = time.time()
     await asyncio.wait(workers)
     assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+async def test_notifies(aconn, dsn):
+    nconn = await psycopg3.AsyncConnection.connect(dsn)
+    npid = nconn.pgconn.backend_pid
+
+    async def notifier():
+        await asyncio.sleep(0.25)
+        nconn.pgconn.exec_(b"notify foo, '1'")
+        await asyncio.sleep(0.25)
+        nconn.pgconn.exec_(b"notify foo, '2'")
+        await nconn.close()
+
+    async def receiver():
+        aconn.pgconn.exec_(b"listen foo")
+        gen = aconn.notifies()
+        async for n in gen:
+            ns.append((n, time.time()))
+            if len(ns) >= 2:
+                gen.send(True)
+
+    ns = []
+    t0 = time.time()
+    workers = [notifier(), receiver()]
+    await asyncio.wait(workers)
+    assert len(ns) == 2
+
+    n, t1 = ns[0]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "1"
+    assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+    n, t1 = ns[1]
+    assert n.pid == npid
+    assert n.channel == "foo"
+    assert n.payload == "2"
+    assert t1 - t0 == pytest.approx(0.5, abs=0.05)
index 7bcaa60c09e3b1b5db85af816d8a1105f9c8838f..4bcd95693957c1295fb5868e21c4f77db4335b3a 100644 (file)
@@ -303,3 +303,41 @@ def test_notice_handlers(conn, caplog):
 
     with pytest.raises(ValueError):
         conn.remove_notice_handler(cb1)
+
+
+def test_notify_handlers(conn):
+    nots1 = []
+    nots2 = []
+
+    def cb1(n):
+        nots1.append(n)
+
+    conn.add_notify_handler(cb1)
+    conn.add_notify_handler(lambda n: nots2.append(n))
+
+    conn.autocommit = True
+    cur = conn.cursor()
+    cur.execute("listen foo")
+    cur.execute("notify foo, 'n1'")
+
+    assert len(nots1) == 1
+    n = nots1[0]
+    assert n.channel == "foo"
+    assert n.payload == "n1"
+    assert n.pid == conn.pgconn.backend_pid
+
+    assert len(nots2) == 1
+    assert nots2[0] == nots1[0]
+
+    conn.remove_notify_handler(cb1)
+    cur.execute("notify foo, 'n2'")
+
+    assert len(nots1) == 1
+    assert len(nots2) == 2
+    n = nots2[1]
+    assert n.channel == "foo"
+    assert n.payload == "n2"
+    assert n.pid == conn.pgconn.backend_pid
+
+    with pytest.raises(ValueError):
+        conn.remove_notify_handler(cb1)
index 585cb77385be74a4b64f2356f94f710e53777436..6953d5b96ae4eb97aa8cc61c7d54bbf6e3463804 100644 (file)
@@ -314,3 +314,41 @@ async def test_notice_handlers(aconn, caplog):
 
     with pytest.raises(ValueError):
         aconn.remove_notice_handler(cb1)
+
+
+async def test_notify_handlers(aconn):
+    nots1 = []
+    nots2 = []
+
+    def cb1(n):
+        nots1.append(n)
+
+    aconn.add_notify_handler(cb1)
+    aconn.add_notify_handler(lambda n: nots2.append(n))
+
+    aconn.autocommit = True
+    cur = aconn.cursor()
+    await cur.execute("listen foo")
+    await cur.execute("notify foo, 'n1'")
+
+    assert len(nots1) == 1
+    n = nots1[0]
+    assert n.channel == "foo"
+    assert n.payload == "n1"
+    assert n.pid == aconn.pgconn.backend_pid
+
+    assert len(nots2) == 1
+    assert nots2[0] == nots1[0]
+
+    aconn.remove_notify_handler(cb1)
+    await cur.execute("notify foo, 'n2'")
+
+    assert len(nots1) == 1
+    assert len(nots2) == 2
+    n = nots2[1]
+    assert n.channel == "foo"
+    assert n.payload == "n2"
+    assert n.pid == aconn.pgconn.backend_pid
+
+    with pytest.raises(ValueError):
+        aconn.remove_notify_handler(cb1)