]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added notice handler to pq.PGconn
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 May 2020 16:54:40 +0000 (04:54 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 May 2020 17:53:21 +0000 (05:53 +1200)
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/libpq.pxd
psycopg3/pq/pq_ctypes.py
psycopg3/pq/pq_cython.pxd
psycopg3/pq/pq_cython.pyx
psycopg3/pq/proto.py
tests/pq/test_pgconn.py

index 208eeded6d111206bfebcde45978dc57f9f5b6fc..46fa59f033315ba429f7b91ffa5c29133e401aaf 100644 (file)
@@ -6,7 +6,7 @@ libpq access using ctypes
 
 import ctypes
 import ctypes.util
-from ctypes import Structure, POINTER
+from ctypes import Structure, CFUNCTYPE, POINTER
 from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
 from typing import List, Tuple
 
@@ -461,6 +461,15 @@ PQmakeEmptyPGresult.argtypes = [PGconn_ptr, c_int]
 PQmakeEmptyPGresult.restype = PGresult_ptr
 
 
+# 33.12. Notice Processing
+
+PQnoticeReceiver = CFUNCTYPE(None, c_void_p, PGresult_ptr)
+
+PQsetNoticeReceiver = pq.PQsetNoticeReceiver
+PQsetNoticeReceiver.argtypes = [PGconn_ptr, PQnoticeReceiver, c_void_p]
+PQsetNoticeReceiver.restype = PQnoticeReceiver
+
+
 def generate_stub() -> None:
     import re
     from ctypes import _CFuncPtr
index 51b332b91cb1fdd69351bd880801079370eef622..cb09bc12130d3ba75613fc6e2e9d014048c6ebf7 100644 (file)
@@ -4,7 +4,7 @@ types stub for ctypes functions
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Optional, Sequence, NewType
+from typing import Any, Callable, Optional, Sequence, NewType
 from ctypes import Array, pointer
 from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
 
@@ -61,6 +61,9 @@ def PQsendQueryPrepared(
     arg6: Optional[Array[c_int]],
     arg7: int,
 ) -> int: ...
+def PQsetNoticeReceiver(
+    arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
+) -> Callable[[Any], PGresult_struct]: ...
 
 # fmt: off
 # autogenerated: start
index 9cf5b71c2be1216032d8ec8c51d3108fe31b76a3..ecb40a4eee9f0b8bd8aec5e1c1c9db28f057dbb2 100644 (file)
@@ -214,3 +214,7 @@ cdef extern from "libpq-fe.h":
     PGresult *PQmakeEmptyPGresult(PGconn *conn, ExecStatusType status)
     int PQlibVersion()
 
+    # 33.12. Notice Processing
+    ctypedef void (*PQnoticeReceiver)(void *arg, const PGresult *res)
+    PQnoticeReceiver PQsetNoticeReceiver(
+        PGconn *conn, PQnoticeReceiver prog, void *arg)
index 8a7fa3881ce9648c8ac7a39f779b2558ad3b8c9e..285e9c3fa9d8229e125814da601bf3ea424c9edc 100644 (file)
@@ -8,6 +8,9 @@ implementation.
 
 # Copyright (C) 2020 The Psycopg Team
 
+import logging
+from weakref import ref
+
 from ctypes import Array, pointer, string_at
 from ctypes import c_char_p, c_int, c_size_t, c_ulong
 from typing import Any, Callable, List, Optional, Sequence
@@ -30,16 +33,45 @@ if TYPE_CHECKING:
 
 __impl__ = "ctypes"
 
+logger = logging.getLogger("psycopg3")
+
 
 def version() -> int:
     return impl.PQlibVersion()
 
 
 class PGconn:
-    __slots__ = ("pgconn_ptr",)
+    __slots__ = (
+        "pgconn_ptr",
+        "notice_callback",
+        "_notice_receiver",
+        "__weakref__",
+    )
 
     def __init__(self, pgconn_ptr: impl.PGconn_struct):
         self.pgconn_ptr: Optional[impl.PGconn_struct] = pgconn_ptr
+        self.notice_callback: Optional[Callable[..., None]] = None
+
+        w = ref(self)
+
+        @impl.PQnoticeReceiver  # type: ignore
+        def notice_receiver(
+            arg: Any, result_ptr: impl.PGresult_struct
+        ) -> None:
+            pgconn = w()
+            if pgconn is None or pgconn.notice_callback is None:
+                return
+
+            res = PGresult(result_ptr)
+            try:
+                pgconn.notice_callback(res)
+            except Exception as e:
+                logger.exception("error in notice receiver: %s", e)
+
+            res.pgresult_ptr = None  # avoid destroying the pgresult_ptr
+
+        impl.PQsetNoticeReceiver(pgconn_ptr, notice_receiver, None)
+        self._notice_receiver = notice_receiver
 
     def __del__(self) -> None:
         self.finish()
index b48d88908aa12ea5a04d8182b61f1aa0d0d39810..daff3de723d523c3948801774a6360d4fab95409 100644 (file)
@@ -10,6 +10,8 @@ cdef class PGconn:
     @staticmethod
     cdef PGconn _from_ptr(impl.PGconn *ptr)
 
+    cdef public object notice_callback
+
     cdef int _ensure_pgconn(self) except 0
     cdef char *_call_bytes(self, conn_bytes_f func) except NULL
     cdef int _call_int(self, conn_int_f func) except -1
index 8b090917b94132218ef7158ef2c9bbebb1dfed03..0380e73c980b17595c8716e433fec6a5290727c5 100644 (file)
@@ -6,6 +6,7 @@ libpq Python wrapper using cython bindings.
 
 from cpython.mem cimport PyMem_Malloc, PyMem_Free
 
+import logging
 from typing import List, Optional, Sequence
 
 from psycopg3.pq cimport libpq as impl
@@ -26,16 +27,34 @@ from psycopg3.pq.enums import (
 
 __impl__ = 'c'
 
+logger = logging.getLogger('psycopg3')
+
 
 def version():
     return impl.PQlibVersion()
 
 
+cdef void notice_receiver(void *arg, const impl.PGresult *res_ptr):
+    cdef PGconn pgconn = <object>arg
+    if pgconn.notice_callback is None:
+        return
+
+    cdef PGresult res = PGresult._from_ptr(<impl.PGresult *>res_ptr)
+    try:
+        pgconn.notice_callback(res)
+    except Exception as e:
+        logger.exception("error in notice receiver: %s", e)
+
+    res.pgresult_ptr = NULL  # avoid destroying the pgresult_ptr
+
+
 cdef class PGconn:
     @staticmethod
     cdef PGconn _from_ptr(impl.PGconn *ptr):
         cdef PGconn rv = PGconn.__new__(PGconn)
         rv.pgconn_ptr = ptr
+
+        impl.PQsetNoticeReceiver(ptr, notice_receiver, <void *>rv)
         return rv
 
     def __cinit__(self):
@@ -712,4 +731,3 @@ cdef class Escaping:
         rv = out[:len_out]
         impl.PQfreemem(out)
         return rv
-
index 23bf76615a5b58d1a3f9a33cfede67ac02458ee4..c5d69e67ac62fce20c2964b074862e413c8f1cdc 100644 (file)
@@ -4,7 +4,7 @@ Protocol objects to represent objects exposed by different pq implementations.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, List, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, Sequence, TYPE_CHECKING
 from typing_extensions import Protocol
 
 from .enums import (
@@ -22,6 +22,9 @@ if TYPE_CHECKING:
 
 
 class PGconn(Protocol):
+
+    notice_callback: Optional[Callable[["PGresult"], None]]
+
     @classmethod
     def connect(cls, conninfo: bytes) -> "PGconn":
         ...
index d023ade144a103c2d229983b31d37c74011490d0..669939b2a66bd421aa8b2d91114ea87730eb6868 100644 (file)
@@ -1,9 +1,11 @@
 import os
+import logging
 from select import select
 
 import pytest
 
 import psycopg3
+import psycopg3.generators
 
 
 def test_connectdb(pq, dsn):
@@ -327,3 +329,62 @@ def test_make_empty_result(pq, pgconn):
     res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
     assert res.status == pq.ExecStatus.FATAL_ERROR
     assert res.error_message == b""
+
+
+def test_notice_nohandler(pq, pgconn):
+    res = pgconn.exec_(
+        b"""
+do $$
+begin
+    raise notice 'hello notice';
+end
+$$ language plpgsql
+    """
+    )
+    assert res.status == pq.ExecStatus.COMMAND_OK
+
+
+def test_notice(pq, pgconn):
+    msgs = []
+
+    def callback(res):
+        assert res.status == pq.ExecStatus.NONFATAL_ERROR
+        msgs.append(res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY))
+
+    pgconn.notice_callback = callback
+    res = pgconn.exec_(
+        b"""
+do $$
+begin
+    raise notice 'hello notice';
+end
+$$ language plpgsql
+    """
+    )
+
+    assert res.status == pq.ExecStatus.COMMAND_OK
+    assert msgs and msgs[0] == b"hello notice"
+
+
+def test_notice_error(pq, pgconn, caplog):
+    caplog.set_level(logging.WARNING, logger="psycopg3")
+
+    def callback(res):
+        raise Exception("hello error")
+
+    pgconn.notice_callback = callback
+    res = pgconn.exec_(
+        b"""
+do $$
+begin
+    raise notice 'hello notice';
+end
+$$ language plpgsql
+    """
+    )
+
+    assert res.status == pq.ExecStatus.COMMAND_OK
+    assert len(caplog.records) == 1
+    rec = caplog.records[0]
+    assert rec.levelno == logging.ERROR
+    assert "hello error" in rec.message