From: Daniele Varrazzo Date: Thu, 21 May 2020 17:56:16 +0000 (+1200) Subject: Added notification handlers on connection objects X-Git-Tag: 3.0.dev0~506 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d4109f8beaf78ea1e8c2d271f587723487f93bb0;p=thirdparty%2Fpsycopg.git Added notification handlers on connection objects --- diff --git a/psycopg3/connection.py b/psycopg3/connection.py index b538d2183..63aceb5dc 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -9,6 +9,8 @@ import logging import asyncio import threading from typing import Any, Callable, List, Optional, Type, cast +from weakref import ref, ReferenceType +from functools import partial from . import pq from . import errors as e @@ -19,6 +21,7 @@ from .conninfo import make_conninfo from .waiting import wait, wait_async logger = logging.getLogger(__name__) +package_logger = logging.getLogger("psycopg3") connect: Callable[[str], proto.PQGen[pq.proto.PGconn]] @@ -36,6 +39,8 @@ else: connect = generators.connect execute = generators.execute +NoticeCallback = Callable[[pq.proto.PGresult], None] + class BaseConnection: """ @@ -67,9 +72,16 @@ class BaseConnection: self._autocommit = False self.dumpers: proto.DumpersMap = {} self.loaders: proto.LoadersMap = {} + self._notice_callbacks: List[NoticeCallback] = [] # name of the postgres encoding (in bytes) self._pgenc = b"" + wself = ref(self) + + pgconn.notice_callback = partial( + BaseConnection._notice_callback, wself + ) + @property def closed(self) -> bool: return self.status == self.ConnStatus.BAD @@ -128,6 +140,27 @@ class BaseConnection: else: return "UTF8" + def add_notice_callback(self, callback: NoticeCallback) -> None: + self._notice_callbacks.append(callback) + + def remove_notice_callback(self, callback: NoticeCallback) -> None: + self._notice_callbacks.remove(callback) + + @staticmethod + def _notice_callback( + wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult + ) -> None: + self = wself() + if self is None: + return + for cb in self._notice_callbacks: + try: + cb(res) + except Exception as ex: + package_logger.exception( + "error processing notice callback '%s': %s", cb, ex + ) + class Connection(BaseConnection): """ diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index 0e0455bf4..31c304406 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -1,4 +1,5 @@ import pytest +import logging import psycopg3 from psycopg3 import AsyncConnection @@ -215,3 +216,70 @@ def test_broken_connection(aconn, loop): cur.execute("select pg_terminate_backend(pg_backend_pid())") ) assert aconn.closed + + +def test_notice_callbacks(aconn, loop, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3") + messages = [] + severities = [] + + def cb1(res): + messages.append( + res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY) + ) + + def cb2(res): + raise Exception("hello from cb2") + + def cb3(res): + severities.append( + res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED) + ) + + aconn.add_notice_callback(cb1) + aconn.add_notice_callback(cb2) + aconn.add_notice_callback("the wrong thing") + aconn.add_notice_callback(cb3) + + cur = aconn.cursor() + loop.run_until_complete( + cur.execute( + """ +do $$ +begin + raise notice 'hello notice'; +end +$$ language plpgsql + """ + ) + ) + assert messages == [b"hello notice"] + assert severities == [b"NOTICE"] + + assert len(caplog.records) == 2 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello from cb2" in rec.message + rec = caplog.records[1] + assert rec.levelno == logging.ERROR + assert "the wrong thing" in rec.message + + aconn.remove_notice_callback(cb1) + aconn.remove_notice_callback("the wrong thing") + loop.run_until_complete( + cur.execute( + """ +do $$ +begin + raise warning 'hello warning'; +end +$$ language plpgsql + """ + ) + ) + assert len(caplog.records) == 3 + assert messages == [b"hello notice"] + assert severities == [b"NOTICE", b"WARNING"] + + with pytest.raises(ValueError): + aconn.remove_notice_callback(cb1) diff --git a/tests/test_connection.py b/tests/test_connection.py index 6e11bb3e9..c1dce1bbc 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,5 @@ import pytest +import logging import psycopg3 from psycopg3 import Connection @@ -205,3 +206,66 @@ def test_broken_connection(conn): with pytest.raises(psycopg3.DatabaseError): cur.execute("select pg_terminate_backend(pg_backend_pid())") assert conn.closed + + +def test_notice_callbacks(conn, caplog): + caplog.set_level(logging.WARNING, logger="psycopg3") + messages = [] + severities = [] + + def cb1(res): + messages.append( + res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY) + ) + + def cb2(res): + raise Exception("hello from cb2") + + def cb3(res): + severities.append( + res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED) + ) + + conn.add_notice_callback(cb1) + conn.add_notice_callback(cb2) + conn.add_notice_callback("the wrong thing") + conn.add_notice_callback(cb3) + + cur = conn.cursor() + cur.execute( + """ +do $$ +begin + raise notice 'hello notice'; +end +$$ language plpgsql + """ + ) + assert messages == [b"hello notice"] + assert severities == [b"NOTICE"] + + assert len(caplog.records) == 2 + rec = caplog.records[0] + assert rec.levelno == logging.ERROR + assert "hello from cb2" in rec.message + rec = caplog.records[1] + assert rec.levelno == logging.ERROR + assert "the wrong thing" in rec.message + + conn.remove_notice_callback(cb1) + conn.remove_notice_callback("the wrong thing") + cur.execute( + """ +do $$ +begin + raise warning 'hello warning'; +end +$$ language plpgsql + """ + ) + assert len(caplog.records) == 3 + assert messages == [b"hello notice"] + assert severities == [b"NOTICE", b"WARNING"] + + with pytest.raises(ValueError): + conn.remove_notice_callback(cb1)