]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added notification handlers on connection objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 May 2020 17:56:16 +0000 (05:56 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 May 2020 17:56:16 +0000 (05:56 +1200)
psycopg3/connection.py
tests/test_async_connection.py
tests/test_connection.py

index b538d2183b9282a21f7f89bddbefaf2cf401ab19..63aceb5dc9960dd781d7abe8193a980e092e0639 100644 (file)
@@ -9,6 +9,8 @@ import logging
 import asyncio
 import threading
 from typing import Any, Callable, List, Optional, Type, cast
+from weakref import ref, ReferenceType
+from functools import partial
 
 from . import pq
 from . import errors as e
@@ -19,6 +21,7 @@ from .conninfo import make_conninfo
 from .waiting import wait, wait_async
 
 logger = logging.getLogger(__name__)
+package_logger = logging.getLogger("psycopg3")
 
 
 connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
@@ -36,6 +39,8 @@ else:
     connect = generators.connect
     execute = generators.execute
 
+NoticeCallback = Callable[[pq.proto.PGresult], None]
+
 
 class BaseConnection:
     """
@@ -67,9 +72,16 @@ class BaseConnection:
         self._autocommit = False
         self.dumpers: proto.DumpersMap = {}
         self.loaders: proto.LoadersMap = {}
+        self._notice_callbacks: List[NoticeCallback] = []
         # name of the postgres encoding (in bytes)
         self._pgenc = b""
 
+        wself = ref(self)
+
+        pgconn.notice_callback = partial(
+            BaseConnection._notice_callback, wself
+        )
+
     @property
     def closed(self) -> bool:
         return self.status == self.ConnStatus.BAD
@@ -128,6 +140,27 @@ class BaseConnection:
         else:
             return "UTF8"
 
+    def add_notice_callback(self, callback: NoticeCallback) -> None:
+        self._notice_callbacks.append(callback)
+
+    def remove_notice_callback(self, callback: NoticeCallback) -> None:
+        self._notice_callbacks.remove(callback)
+
+    @staticmethod
+    def _notice_callback(
+        wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
+    ) -> None:
+        self = wself()
+        if self is None:
+            return
+        for cb in self._notice_callbacks:
+            try:
+                cb(res)
+            except Exception as ex:
+                package_logger.exception(
+                    "error processing notice callback '%s': %s", cb, ex
+                )
+
 
 class Connection(BaseConnection):
     """
index 0e0455bf4ce248f87cec8bd952ef2976f32983e3..31c3044062beb2139825b0dd3e9788f08928c6c7 100644 (file)
@@ -1,4 +1,5 @@
 import pytest
+import logging
 
 import psycopg3
 from psycopg3 import AsyncConnection
@@ -215,3 +216,70 @@ def test_broken_connection(aconn, loop):
             cur.execute("select pg_terminate_backend(pg_backend_pid())")
         )
     assert aconn.closed
+
+
+def test_notice_callbacks(aconn, loop, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3")
+    messages = []
+    severities = []
+
+    def cb1(res):
+        messages.append(
+            res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY)
+        )
+
+    def cb2(res):
+        raise Exception("hello from cb2")
+
+    def cb3(res):
+        severities.append(
+            res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED)
+        )
+
+    aconn.add_notice_callback(cb1)
+    aconn.add_notice_callback(cb2)
+    aconn.add_notice_callback("the wrong thing")
+    aconn.add_notice_callback(cb3)
+
+    cur = aconn.cursor()
+    loop.run_until_complete(
+        cur.execute(
+            """
+do $$
+begin
+    raise notice 'hello notice';
+end
+$$ language plpgsql
+    """
+        )
+    )
+    assert messages == [b"hello notice"]
+    assert severities == [b"NOTICE"]
+
+    assert len(caplog.records) == 2
+    rec = caplog.records[0]
+    assert rec.levelno == logging.ERROR
+    assert "hello from cb2" in rec.message
+    rec = caplog.records[1]
+    assert rec.levelno == logging.ERROR
+    assert "the wrong thing" in rec.message
+
+    aconn.remove_notice_callback(cb1)
+    aconn.remove_notice_callback("the wrong thing")
+    loop.run_until_complete(
+        cur.execute(
+            """
+do $$
+begin
+    raise warning 'hello warning';
+end
+$$ language plpgsql
+    """
+        )
+    )
+    assert len(caplog.records) == 3
+    assert messages == [b"hello notice"]
+    assert severities == [b"NOTICE", b"WARNING"]
+
+    with pytest.raises(ValueError):
+        aconn.remove_notice_callback(cb1)
index 6e11bb3e91e17095e33ac4d91beb4d5dc8abecf4..c1dce1bbc296dd263a1c36d5517cafe922e62414 100644 (file)
@@ -1,4 +1,5 @@
 import pytest
+import logging
 
 import psycopg3
 from psycopg3 import Connection
@@ -205,3 +206,66 @@ def test_broken_connection(conn):
     with pytest.raises(psycopg3.DatabaseError):
         cur.execute("select pg_terminate_backend(pg_backend_pid())")
     assert conn.closed
+
+
+def test_notice_callbacks(conn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3")
+    messages = []
+    severities = []
+
+    def cb1(res):
+        messages.append(
+            res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY)
+        )
+
+    def cb2(res):
+        raise Exception("hello from cb2")
+
+    def cb3(res):
+        severities.append(
+            res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED)
+        )
+
+    conn.add_notice_callback(cb1)
+    conn.add_notice_callback(cb2)
+    conn.add_notice_callback("the wrong thing")
+    conn.add_notice_callback(cb3)
+
+    cur = conn.cursor()
+    cur.execute(
+        """
+do $$
+begin
+    raise notice 'hello notice';
+end
+$$ language plpgsql
+    """
+    )
+    assert messages == [b"hello notice"]
+    assert severities == [b"NOTICE"]
+
+    assert len(caplog.records) == 2
+    rec = caplog.records[0]
+    assert rec.levelno == logging.ERROR
+    assert "hello from cb2" in rec.message
+    rec = caplog.records[1]
+    assert rec.levelno == logging.ERROR
+    assert "the wrong thing" in rec.message
+
+    conn.remove_notice_callback(cb1)
+    conn.remove_notice_callback("the wrong thing")
+    cur.execute(
+        """
+do $$
+begin
+    raise warning 'hello warning';
+end
+$$ language plpgsql
+    """
+    )
+    assert len(caplog.records) == 3
+    assert messages == [b"hello notice"]
+    assert severities == [b"NOTICE", b"WARNING"]
+
+    with pytest.raises(ValueError):
+        conn.remove_notice_callback(cb1)