import asyncio
import threading
from typing import Any, Callable, List, Optional, Type, cast
+from weakref import ref, ReferenceType
+from functools import partial
from . import pq
from . import errors as e
from .waiting import wait, wait_async
logger = logging.getLogger(__name__)
+package_logger = logging.getLogger("psycopg3")
connect: Callable[[str], proto.PQGen[pq.proto.PGconn]]
connect = generators.connect
execute = generators.execute
+NoticeCallback = Callable[[pq.proto.PGresult], None]
+
class BaseConnection:
"""
self._autocommit = False
self.dumpers: proto.DumpersMap = {}
self.loaders: proto.LoadersMap = {}
+ self._notice_callbacks: List[NoticeCallback] = []
# name of the postgres encoding (in bytes)
self._pgenc = b""
+ wself = ref(self)
+
+ pgconn.notice_callback = partial(
+ BaseConnection._notice_callback, wself
+ )
+
@property
def closed(self) -> bool:
return self.status == self.ConnStatus.BAD
else:
return "UTF8"
+ def add_notice_callback(self, callback: NoticeCallback) -> None:
+ self._notice_callbacks.append(callback)
+
+ def remove_notice_callback(self, callback: NoticeCallback) -> None:
+ self._notice_callbacks.remove(callback)
+
+ @staticmethod
+ def _notice_callback(
+ wself: "ReferenceType[BaseConnection]", res: pq.proto.PGresult
+ ) -> None:
+ self = wself()
+ if self is None:
+ return
+ for cb in self._notice_callbacks:
+ try:
+ cb(res)
+ except Exception as ex:
+ package_logger.exception(
+ "error processing notice callback '%s': %s", cb, ex
+ )
+
class Connection(BaseConnection):
"""
import pytest
+import logging
import psycopg3
from psycopg3 import AsyncConnection
cur.execute("select pg_terminate_backend(pg_backend_pid())")
)
assert aconn.closed
+
+
+def test_notice_callbacks(aconn, loop, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3")
+ messages = []
+ severities = []
+
+ def cb1(res):
+ messages.append(
+ res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY)
+ )
+
+ def cb2(res):
+ raise Exception("hello from cb2")
+
+ def cb3(res):
+ severities.append(
+ res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED)
+ )
+
+ aconn.add_notice_callback(cb1)
+ aconn.add_notice_callback(cb2)
+ aconn.add_notice_callback("the wrong thing")
+ aconn.add_notice_callback(cb3)
+
+ cur = aconn.cursor()
+ loop.run_until_complete(
+ cur.execute(
+ """
+do $$
+begin
+ raise notice 'hello notice';
+end
+$$ language plpgsql
+ """
+ )
+ )
+ assert messages == [b"hello notice"]
+ assert severities == [b"NOTICE"]
+
+ assert len(caplog.records) == 2
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello from cb2" in rec.message
+ rec = caplog.records[1]
+ assert rec.levelno == logging.ERROR
+ assert "the wrong thing" in rec.message
+
+ aconn.remove_notice_callback(cb1)
+ aconn.remove_notice_callback("the wrong thing")
+ loop.run_until_complete(
+ cur.execute(
+ """
+do $$
+begin
+ raise warning 'hello warning';
+end
+$$ language plpgsql
+ """
+ )
+ )
+ assert len(caplog.records) == 3
+ assert messages == [b"hello notice"]
+ assert severities == [b"NOTICE", b"WARNING"]
+
+ with pytest.raises(ValueError):
+ aconn.remove_notice_callback(cb1)
import pytest
+import logging
import psycopg3
from psycopg3 import Connection
with pytest.raises(psycopg3.DatabaseError):
cur.execute("select pg_terminate_backend(pg_backend_pid())")
assert conn.closed
+
+
+def test_notice_callbacks(conn, caplog):
+ caplog.set_level(logging.WARNING, logger="psycopg3")
+ messages = []
+ severities = []
+
+ def cb1(res):
+ messages.append(
+ res.error_field(psycopg3.pq.DiagnosticField.MESSAGE_PRIMARY)
+ )
+
+ def cb2(res):
+ raise Exception("hello from cb2")
+
+ def cb3(res):
+ severities.append(
+ res.error_field(psycopg3.pq.DiagnosticField.SEVERITY_NONLOCALIZED)
+ )
+
+ conn.add_notice_callback(cb1)
+ conn.add_notice_callback(cb2)
+ conn.add_notice_callback("the wrong thing")
+ conn.add_notice_callback(cb3)
+
+ cur = conn.cursor()
+ cur.execute(
+ """
+do $$
+begin
+ raise notice 'hello notice';
+end
+$$ language plpgsql
+ """
+ )
+ assert messages == [b"hello notice"]
+ assert severities == [b"NOTICE"]
+
+ assert len(caplog.records) == 2
+ rec = caplog.records[0]
+ assert rec.levelno == logging.ERROR
+ assert "hello from cb2" in rec.message
+ rec = caplog.records[1]
+ assert rec.levelno == logging.ERROR
+ assert "the wrong thing" in rec.message
+
+ conn.remove_notice_callback(cb1)
+ conn.remove_notice_callback("the wrong thing")
+ cur.execute(
+ """
+do $$
+begin
+ raise warning 'hello warning';
+end
+$$ language plpgsql
+ """
+ )
+ assert len(caplog.records) == 3
+ assert messages == [b"hello notice"]
+ assert severities == [b"NOTICE", b"WARNING"]
+
+ with pytest.raises(ValueError):
+ conn.remove_notice_callback(cb1)