import logging
import asyncio
import threading
-from typing import Any, Callable, List, Optional, Type, cast
+from typing import Any, AsyncGenerator, Callable, Generator, List, NamedTuple
+from typing import Optional, Type, cast
from weakref import ref, ReferenceType
from functools import partial
from .pq import TransactionStatus, ExecStatus
from .conninfo import make_conninfo
from .waiting import wait, wait_async
+from .generators import notifies
logger = logging.getLogger(__name__)
package_logger = logging.getLogger("psycopg3")
-
connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
execute: Callable[[pq.proto.PGconn], proto.PQGen[List[pq.proto.PGresult]]]
connect = generators.connect
execute = generators.execute
+
+class Notify(NamedTuple):
+ channel: str
+ payload: str
+ pid: int
+
+
NoticeHandler = Callable[[e.Diagnostic], None]
+NotifyHandler = Callable[[Notify], None]
class BaseConnection:
self.dumpers: proto.DumpersMap = {}
self.loaders: proto.LoadersMap = {}
self._notice_handlers: List[NoticeHandler] = []
+ self._notify_handlers: List[NotifyHandler] = []
# name of the postgres encoding (in bytes)
self._pgenc = b""
wself = ref(self)
pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
+ pgconn.notify_handler = partial(BaseConnection._notify_handler, wself)
@property
def closed(self) -> bool:
"error processing notice callback '%s': %s", cb, ex
)
+ def add_notify_handler(self, callback: NotifyHandler) -> None:
+ self._notify_handlers.append(callback)
+
+ def remove_notify_handler(self, callback: NotifyHandler) -> None:
+ self._notify_handlers.remove(callback)
+
+ @staticmethod
+ def _notify_handler(
+ wself: "ReferenceType[BaseConnection]", pgn: pq.PGnotify
+ ) -> None:
+ self = wself()
+ if self is None or not self._notify_handlers:
+ return
+
+ decode = self.codec.decode
+ n = Notify(decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid)
+ for cb in self._notify_handlers:
+ cb(n)
+
class Connection(BaseConnection):
"""
if result.status != ExecStatus.TUPLES_OK:
raise e.error_from_result(result)
+ def notifies(self) -> Generator[Optional[Notify], bool, None]:
+ decode = self.codec.decode
+ while 1:
+ with self.lock:
+ ns = self.wait(notifies(self.pgconn))
+ for pgn in ns:
+ n = Notify(
+ decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid
+ )
+ if (yield n):
+ yield None # for the send who stopped us
+ return
+
class AsyncConnection(BaseConnection):
"""
(result,) = await self.wait(gen)
if result.status != ExecStatus.TUPLES_OK:
raise e.error_from_result(result)
+
+ async def notifies(self) -> AsyncGenerator[Optional[Notify], bool]:
+ decode = self.codec.decode
+ while 1:
+ async with self.lock:
+ ns = await self.wait(notifies(self.pgconn))
+ for pgn in ns:
+ n = Notify(
+ decode(pgn.relname)[0], decode(pgn.extra)[0], pgn.be_pid
+ )
+ if (yield n):
+ yield None
+ return
ready = yield pgconn.socket, Wait.RW
if ready & Ready.R:
+ # This call may read notifies: they will be saved in the
+ # PGconn buffer and passed to Python later, in `fetch()`.
pgconn.consume_input()
continue
if pgconn.is_busy():
yield pgconn.socket, Wait.R
continue
+
+ # Consume notifies
+ while 1:
+ n = pgconn.notifies()
+ if n is None:
+ break
+ if pgconn.notify_handler is not None:
+ pgconn.notify_handler(n)
+
res = pgconn.get_result()
if res is None:
break
break
return results
+
+
+def notifies(pgconn: pq.proto.PGconn) -> PQGen[List[pq.PGnotify]]:
+ yield pgconn.socket, Wait.R
+ pgconn.consume_input()
+
+ ns = []
+ while 1:
+ n = pgconn.notifies()
+ if n is not None:
+ ns.append(n)
+ else:
+ break
+
+ return ns
results: List[pq.proto.PGresult] = []
cdef libpq.PGconn *pgconn_ptr = pgconn.pgconn_ptr
cdef int status
+ cdef libpq.PGnotify *notify
# Sending the query
while 1:
status = yield libpq.PQsocket(pgconn_ptr), WAIT_RW
if status & READY_R:
+ # This call may read notifies which will be saved in the
+ # PGconn buffer and passed to Python later.
if 1 != libpq.PQconsumeInput(pgconn_ptr):
raise pq.PQerror(
f"consuming input failed: {pq.error_message(pgconn)}")
yield wr
continue
+ # Consume notifies
+ if pgconn.notify_handler:
+ while 1:
+ pynotify = pgconn.notifies()
+ if pynotify is None:
+ break
+ pgconn.notify_handler(pynotify)
+ else:
+ while 1:
+ notify = libpq.PQnotifies(pgconn_ptr)
+ if notify is NULL:
+ break
+ libpq.PQfreemem(notify)
+
res = libpq.PQgetResult(pgconn_ptr)
if res is NULL:
break
Format,
)
from .encodings import py_codecs
-from .misc import error_message, ConninfoOption, PQerror
+from .misc import error_message, ConninfoOption, PQerror, PGnotify
from . import proto
logger = logging.getLogger(__name__)
"DiagnosticField",
"Format",
"PGconn",
+ "PGnotify",
"Conninfo",
"PQerror",
"error_message",
__slots__ = (
"pgconn_ptr",
"notice_handler",
+ "notify_handler",
"_notice_receiver",
"_procpid",
"__weakref__",
def __init__(self, pgconn_ptr: impl.PGconn_struct):
self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
- self.notice_handler: Optional[Callable[..., None]] = None
+ self.notice_handler: Optional[
+ Callable[["pq.proto.PGresult"], None]
+ ] = None
+ self.notify_handler: Optional[Callable[[PGnotify], None]] = None
self._notice_receiver = impl.PQnoticeReceiver( # type: ignore
partial(notice_receiver, wconn=ref(self))
raise PQerror(f"flushing failed: {error_message(self)}")
return rv
- def notifies(self) -> Optional["PGnotify"]:
+ def notifies(self) -> Optional[PGnotify]:
ptr = impl.PQnotifies(self.pgconn_ptr)
if ptr:
c = ptr.contents
cdef impl.PGconn* pgconn_ptr
cdef object __weakref__
cdef public object notice_handler
+ cdef public object notify_handler
cdef pid_t _procpid
@staticmethod
)
if TYPE_CHECKING:
- from .misc import ConninfoOption # noqa
+ from .misc import PGnotify, ConninfoOption # noqa
class PGconn(Protocol):
notice_handler: Optional[Callable[["PGresult"], None]]
+ notify_handler: Optional[Callable[["PGnotify"], None]]
@classmethod
def connect(cls, conninfo: bytes) -> "PGconn":
def flush(self) -> int:
...
+ def notifies(self) -> Optional["PGnotify"]:
+ ...
+
def make_empty_result(self, exec_status: ExecStatus) -> "PGresult":
...
assert out == "", out.strip().splitlines()[-1]
finally:
shutil.rmtree(dir, ignore_errors=True)
+
+
+@pytest.mark.slow
+def test_notifies(conn, dsn):
+ nconn = psycopg3.connect(dsn)
+ npid = nconn.pgconn.backend_pid
+
+ def notifier():
+ time.sleep(0.25)
+ nconn.pgconn.exec_(b"notify foo, '1'")
+ time.sleep(0.25)
+ nconn.pgconn.exec_(b"notify foo, '2'")
+ nconn.close()
+
+ conn.pgconn.exec_(b"listen foo")
+ t0 = time.time()
+ t = threading.Thread(target=notifier)
+ t.start()
+ ns = []
+ gen = conn.notifies()
+ for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ gen.send(True)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
t0 = time.time()
await asyncio.wait(workers)
assert time.time() - t0 < 0.8, "something broken in concurrency"
+
+
+@pytest.mark.slow
+async def test_notifies(aconn, dsn):
+ nconn = await psycopg3.AsyncConnection.connect(dsn)
+ npid = nconn.pgconn.backend_pid
+
+ async def notifier():
+ await asyncio.sleep(0.25)
+ nconn.pgconn.exec_(b"notify foo, '1'")
+ await asyncio.sleep(0.25)
+ nconn.pgconn.exec_(b"notify foo, '2'")
+ await nconn.close()
+
+ async def receiver():
+ aconn.pgconn.exec_(b"listen foo")
+ gen = aconn.notifies()
+ async for n in gen:
+ ns.append((n, time.time()))
+ if len(ns) >= 2:
+ gen.send(True)
+
+ ns = []
+ t0 = time.time()
+ workers = [notifier(), receiver()]
+ await asyncio.wait(workers)
+ assert len(ns) == 2
+
+ n, t1 = ns[0]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "1"
+ assert t1 - t0 == pytest.approx(0.25, abs=0.05)
+
+ n, t1 = ns[1]
+ assert n.pid == npid
+ assert n.channel == "foo"
+ assert n.payload == "2"
+ assert t1 - t0 == pytest.approx(0.5, abs=0.05)
with pytest.raises(ValueError):
conn.remove_notice_handler(cb1)
+
+
+def test_notify_handlers(conn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ conn.add_notify_handler(cb1)
+ conn.add_notify_handler(lambda n: nots2.append(n))
+
+ conn.autocommit = True
+ cur = conn.cursor()
+ cur.execute("listen foo")
+ cur.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == conn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ conn.remove_notify_handler(cb1)
+ cur.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == conn.pgconn.backend_pid
+
+ with pytest.raises(ValueError):
+ conn.remove_notify_handler(cb1)
with pytest.raises(ValueError):
aconn.remove_notice_handler(cb1)
+
+
+async def test_notify_handlers(aconn):
+ nots1 = []
+ nots2 = []
+
+ def cb1(n):
+ nots1.append(n)
+
+ aconn.add_notify_handler(cb1)
+ aconn.add_notify_handler(lambda n: nots2.append(n))
+
+ aconn.autocommit = True
+ cur = aconn.cursor()
+ await cur.execute("listen foo")
+ await cur.execute("notify foo, 'n1'")
+
+ assert len(nots1) == 1
+ n = nots1[0]
+ assert n.channel == "foo"
+ assert n.payload == "n1"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ assert len(nots2) == 1
+ assert nots2[0] == nots1[0]
+
+ aconn.remove_notify_handler(cb1)
+ await cur.execute("notify foo, 'n2'")
+
+ assert len(nots1) == 1
+ assert len(nots2) == 2
+ n = nots2[1]
+ assert n.channel == "foo"
+ assert n.payload == "n2"
+ assert n.pid == aconn.pgconn.backend_pid
+
+ with pytest.raises(ValueError):
+ aconn.remove_notify_handler(cb1)